diff --git a/.clang-tidy b/.clang-tidy index 5c2b78168747787b506b0cd6e72af595eb270670..3815c654fe91969e18403b9e74c8e01971ae804b 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,3 +1,3 @@ CheckOptions: - key: bugprone-reserved-identifier.AllowedIdentifiers - value: '__HIP_PLATFORM_HCC__;__HIP_ROCclr__' + value: '__HIP_PLATFORM_HCC__;__HIP_PLATFORM_AMD__;__HIP_ROCclr__' diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000000000000000000000000000000000..e4d0d47a2e1c765b969a0854296ef6b84af43bbe --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,7 @@ +* @zjing14 @asroy @junliume @illsilin @carlushuang @aosewski +# Documentation files +docs/* @saadrahim @LisaDelaney +*.md @saadrahim @LisaDelaney +*.rst @saadrahim @LisaDelaney +# Header directory +library/include/* @saadrahim @LisaDelaney diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000000000000000000000000000000000..0086358db1eb971c0cfa8739c27518bbc18a5ff4 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: true diff --git a/.github/ISSUE_TEMPLATE/issue_report.yml b/.github/ISSUE_TEMPLATE/issue_report.yml new file mode 100644 index 0000000000000000000000000000000000000000..ef6e6faa1b71b07865f1e436c7a0354d2a0f8c18 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/issue_report.yml @@ -0,0 +1,221 @@ +name: Issue Report +description: File a report for ROCm related issues on Linux and Windows. For issues pertaining to documentation or non-bug related, please open a blank issue located below. +title: "[Issue]: " + +body: +- type: markdown + attributes: + value: | + Thank you for taking the time to fill out this report! + + You can acquire your OS, CPU, GPU (for filling out this report) with the following commands: + + Linux: + echo "OS:" && cat /etc/os-release | grep -E "^(NAME=|VERSION=)"; + echo "CPU: " && cat /proc/cpuinfo | grep "model name" | sort --unique; + echo "GPU:" && /opt/rocm/bin/rocminfo | grep -E "^\s*(Name|Marketing Name)"; + + Windows: + (Get-WmiObject Win32_OperatingSystem).Version + (Get-WmiObject win32_Processor).Name + (Get-WmiObject win32_VideoController).Name +- type: textarea + attributes: + label: Problem Description + description: Describe the issue you encountered. + validations: + required: true +- type: input + attributes: + label: Operating System + description: What is the name and version number of the OS? + placeholder: "e.g. Ubuntu 22.04.3 LTS (Jammy Jellyfish)" + validations: + required: true +- type: input + attributes: + label: CPU + description: What CPU did you encounter the issue on? + placeholder: "e.g. AMD Ryzen 9 5900HX with Radeon Graphics" + validations: + required: true +- type: dropdown + attributes: + label: GPU + description: What GPU(s) did you encounter the issue on (you can select multiple GPUs from the list) + multiple: true + options: + - AMD Instinct MI300X + - AMD Instinct MI300A + - AMD Instinct MI300 + - AMD Instinct MI250X + - AMD Instinct MI250 + - AMD Instinct MI210 + - AMD Instinct MI100 + - AMD Instinct MI50 + - AMD Instinct MI25 + - AMD Radeon Pro V620 + - AMD Radeon Pro VII + - AMD Radeon RX 7900 XTX + - AMD Radeon VII + - AMD Radeon Pro W7900 + - AMD Radeon Pro W7800 + - AMD Radeon Pro W6800 + - AMD Radeon Pro W6600 + - AMD Radeon Pro W5500 + - AMD Radeon RX 7900 XT + - AMD Radeon RX 7600 + - AMD Radeon RX 6950 XT + - AMD Radeon RX 6900 XT + - AMD Radeon RX 6800 XT + - AMD Radeon RX 6800 + - AMD Radeon RX 6750 + - AMD Radeon RX 6700 XT + - AMD Radeon RX 6700 + - AMD Radeon RX 6650 XT + - AMD Radeon RX 6600 XT + - AMD Radeon RX 6600 + - Other + validations: + required: true +- type: input + attributes: + label: Other + description: If you selected Other, please specify +- type: dropdown + attributes: + label: ROCm Version + description: What version(s) of ROCm did you encounter the issue on? + multiple: true + options: + - ROCm 6.0.0 + - ROCm 5.7.1 + - ROCm 5.7.0 + - ROCm 5.6.1 + - ROCm 5.6.0 + - ROCm 5.5.1 + - ROCm 5.5.0 + validations: + required: true +- type: dropdown + attributes: + label: ROCm Component + description: (Optional) If this issue relates to a specific ROCm component, it can be mentioned here. + multiple: true + options: + - Other + - AMD Common Language Runtime + - AMD MIGraphX + - AMD System Management Interface + - amdgpu KCL/autoconf + - amdgpu Kernel-mode GPU Driver + - amdgpu-install + - AOMP + - AOMP Extras + - AqlProfile + - build-infra + - chelsio + - clang-ocl + - Composable Kernel + - dkms + - docker / ROCm-docker + - flang + - gpuburn + - half + - HIP + - HIP Examples + - hipBLAS + - hipBLASLt + - HIPCC + - hipCUB + - hip-examples-private + - hipFFT + - hipfort + - HIPIFY + - hipRAND + - hipSOLVER + - hipSPARSE + - hipSPARSELt + - hipTensor + - hip-tests + - HSA Runtime + - infrastructure + - jenkins-utils + - libdrm + - Linux BPI packaging framework + - llvm-project + - Mesa + - meta + - MIOpen + - MIVisionX + - ml-framework-ci + - MLSEQA_TestRepo + - OpenCL API C++ Bindings + - OpenCL API Headers + - OpenCL Conformance Test Suite + - OpenCL ICD Loader + - perftest-p2p + - prototype + - RCCL + - rccl-rdma-sharp-plugins + - rocALUTION + - rocBLAS + - ROCdbgapi + - ROCdebug-agent + - rocFFT + - ROCgdb + - ROCK + - ROCm Documentation/Website + - ROCm Data Center Tool + - ROCm Examples + - ROCm for Windows + - ROCm Performance Primitives + - ROCm System Management Interface Library + - ROCm Thrust + - ROCm Validation Suite + - rocm_bandwidth_test + - rocm-cmake + - rocm-core + - rocm-docs-core + - rocminfo + - rocMLIR + - rocmtools + - rocPRIM + - rocprofiler + - rocRAND + - ROCR-Runtime + - rocSOLVER + - rocSPARSE + - roctracer + - ROCT-Thunk-Interface + - rocWMMA + - Tensile + - umr + - ibv_rc_pingpong-amd + - mellanox + - mpitest + - Pytorch + - Tensorflow + - APEX + - torchvision + - Magma +- type: textarea + attributes: + label: Steps to Reproduce + description: (Optional) Detailed steps to reproduce the issue. + validations: + required: false + +- type: textarea + attributes: + label: (Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support + description: The output of rocminfo --support could help to better address the problem. + validations: + required: false + +- type: textarea + attributes: + label: Additional Information + description: (Optional) Any additional information that is relevant, e.g. relevant environment variables, dockerfiles, log files, dmesg output (on Linux), etc. + validations: + required: false diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 276690bd4f96f50eb6a5121d237c8c9bcb80224d..0e0a252eb6ee9d8508d2b21f42df3c77a1739832 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -10,3 +10,9 @@ updates: open-pull-requests-limit: 10 schedule: interval: "daily" + labels: + - "documentation" + - "dependencies" + - "ci:docs-only" + reviewers: + - "samjwu" diff --git a/.gitignore b/.gitignore index 4a668630bcf6daaf79ab2df6f8ebe98860d47c8f..b3653bb3a103d2876912eee25687b56eaf9c0601 100644 --- a/.gitignore +++ b/.gitignore @@ -61,8 +61,16 @@ _images/ _static/ _templates/ _toc.yml -docBin/ _doxygen/ # pycache __pycache__/ + +# JetBrains IDE +.idea/ +cmake-build*/ +build*/ + +# Python virtualenv +.venv/ + diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 5f50df2525d2948bbabe36184f23310d47339d8e..9e6678abe5247868c15fb7b47497a9154e4a6051 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,11 +3,6 @@ version: 2 -build: - os: ubuntu-22.04 - tools: - python: "3.8" - sphinx: configuration: docs/conf.py @@ -16,3 +11,8 @@ formats: [htmlzip, pdf, epub] python: install: - requirements: docs/sphinx/requirements.txt + +build: + os: ubuntu-22.04 + tools: + python: "3.8" diff --git a/CHANGELOG.md b/CHANGELOG.md index 31f129b58159076c3d3027e3c36ccc885e4b1baf..4e3feed2dfbc81961e12016a8f0abcd015a1835b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,32 +1,67 @@ -# Change Log for Composable Kernel +# Changelog for Composable Kernel Full documentation for Composable Kernel is not yet available. -## CK 0.2.0 for ROCm 5.5.0 +## (Unreleased) CK -### Fixed -- Fixed a bug in 6-dimensional kernels (#555). -- Fixed grouped ConvBwdWeight test case failure (#524). +### Fixes +None ### Optimizations -- Improve proformance of normalization kernel - -### Added -- Added new cmake flag "DL_KERNELS" must be set to "ON" in order to build the gemm_dl and batched_gemm_multi_d_dl instances. -- Added new cmake flag "DTYPES" which could be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build instance of select data types. -- Added new cmake flag "INSTANCES_ONLY" which will only build CK library and instances without the tests, examples, or profiler. -- Added new feature: if GPU_TARGETS is not set on cmake command line, CK will be built for all targets supported by compiler. -- Added support on MI300A/MI300X. -- Added support on NAVI3x. -- Added user tutorial (#563). -- Added more instances for irregular GEMM sizes (#560). -- Added inter-wave consumer-producer programming model for GEMM kernels (#310). -- Added multi-D GEMM client APIs (#534). -- Added multi-embeddings support (#542). -- Added Navi3x blockwise GEMM and real GEMM support (#541). -- Added Navi grouped ConvBwdWeight support (#505). -- Added MaxPool, AvgPool forward (#815). -- Added MaxPool backward (#750). - -### Changed -- Changed ... +None + +### Additions +* Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126, #1139) + +### Changes +None + +## CK for ROCm 6.0.0 + +### Fixes + * Fixed a hazard associated with inline v_dot (#808) + * Fixed two bugs in grouped convolution backward data without K padding (#848 #876) + +### Optimizations +None + +### Additions +* Added an image to a column kernel (#867) +* Added a column to an image kernel (#930) +* Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985) +* Grouped convolution support for small K and C (#822 #879 #897) +* Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) +* Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) +* Support for Batched Gemm DL (#732) + +### Changes + * Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) + +## CK 0.2.0 for ROCm 5.7.0 + +### Fixes +* Fixed a bug in 6-dimensional kernels (#555) +* Fixed a test case failure with grouped convolution backward weight (#524) + +### Optimizations +* Improved the performance of the normalization kernel + +### Additions +* New CMake flags: + * "DL_KERNELS"-* Must be set to "ON" in order to build the gemm_dl and batched_gemm_multi_d_dl instances + * "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types + * "INSTANCES_ONLY" -- Only builds CK library and instances without tests, examples, or profiler +* New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler +* Support for MI300A/MI300X +* Support for AMD RDNA 3 +* New user tutorial (#563) +* Additional instances for irregular GEMM sizes (#560) +* New inter-wave consumer-producer programming model for GEMM kernels (#310) +* GEMM with support multiple elementwise fusions (multi-D) (#534) +* Multi-embeddings support (#542) +* AMD RDNA 3 blockwise GEMM and real GEMM support (#541) +* AMD RDNA grouped convolution backward weight support (#505) +* MaxPool and AvgPool forward (#815); MaxPool backward (#750) + +### Changes +None diff --git a/CITATION.cff b/CITATION.cff index d35fe9e5870f0a538c1775efd239dc729e11ab88..3813d63812566760017b6f2d3bfeec0ecd11f884 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -59,9 +59,9 @@ authors: family-names: Zhou - given-names: Jianfeng family-names: Yan -repository-code: 'https://github.com/ROCmSoftwarePlatform/composable_kernel' +repository-code: 'https://github.com/ROCm/composable_kernel' abstract: Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for Machine Learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel progarmming languages, like HIP C++. keywords: - 'CK, Composable Kernel, Tensor Coordinate Transformation' license: MIT -license-url: https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/7fc3ed761aa35709d87c8fbbe41dd368648b3541/LICENSE +license-url: https://github.com/ROCm/composable_kernel/blob/7fc3ed761aa35709d87c8fbbe41dd368648b3541/LICENSE diff --git a/CMakeLists.txt b/CMakeLists.txt index 7ef5e1ca073003a8a130743541f8fa63b8505cc4..eefae4ae152b331ea6e8fff837ec48a6eb9a3460 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,8 +1,30 @@ cmake_minimum_required(VERSION 3.14) +if(POLICY CMP0140) + # policies CMP0140 not known to CMake until 3.25 + cmake_policy(SET CMP0140 NEW) +endif() + +get_property(_GENERATOR_IS_MULTI_CONFIG GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG) + +# This has to be initialized before the project() command appears +# Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE +if(_GENERATOR_IS_MULTI_CONFIG) + set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo;MinSizeRel" CACHE STRING + "Available build types (configurations) on multi-config generators") +else() + set(CMAKE_BUILD_TYPE Release CACHE STRING + "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel.") +endif() + +# Default installation path +if(NOT WIN32) + set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "") +endif() set(version 1.1.0) # Check support for CUDA/HIP in Cmake -project(composable_kernel VERSION ${version}) +project(composable_kernel VERSION ${version} LANGUAGES CXX) +include(CTest) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") @@ -16,6 +38,10 @@ if (DTYPES) add_definitions(-DCK_ENABLE_FP8) set(CK_ENABLE_FP8 "ON") endif() + if (DTYPES MATCHES "bf8") + add_definitions(-DCK_ENABLE_BF8) + set(CK_ENABLE_BF8 "ON") + endif() if (DTYPES MATCHES "fp16") add_definitions(-DCK_ENABLE_FP16) set(CK_ENABLE_FP16 "ON") @@ -34,10 +60,15 @@ if (DTYPES) endif() message("DTYPES macro set to ${DTYPES}") else() - add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) + add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) set(CK_ENABLE_ALL_DTYPES "ON") endif() +#for f8/bf8_t type +add_compile_options(-Wno-bit-int-extension) +add_compile_options(-Wno-pass-failed) +add_compile_options(-Wno-switch-default) + if(DL_KERNELS) add_definitions(-DDL_KERNELS) set(CK_ENABLE_DL_KERNELS "ON") @@ -48,15 +79,15 @@ if(INSTANCES_ONLY) set(CK_ENABLE_INSTANCES_ONLY "ON") endif() +include(getopt) + # CK config file to record supported datatypes, etc. -configure_file("${PROJECT_SOURCE_DIR}/include/ck/config.h.in" "${PROJECT_BINARY_DIR}/include/ck/config.h") +configure_file(include/ck/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/config.h) # CK version file to record release version as well as git commit hash find_package(Git REQUIRED) execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMIT_ID OUTPUT_STRIP_TRAILING_WHITESPACE) -configure_file("${PROJECT_SOURCE_DIR}/include/ck/version.h.in" "${PROJECT_BINARY_DIR}/include/ck/version.h") - -enable_testing() +configure_file(include/ck/version.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/version.h) set(ROCM_SYMLINK_LIBS OFF) find_package(ROCM REQUIRED PATHS /opt/rocm) @@ -72,7 +103,7 @@ include(TargetFlags) rocm_setup_version(VERSION ${version}) -list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) +list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip "$ENV{ROCM_PATH}" "$ENV{HIP_PATH}") message("GPU_TARGETS= ${GPU_TARGETS}") @@ -82,26 +113,30 @@ message("checking which targets are supported") #Setting GPU_TARGETS on command line will override this list if(NOT PROFILER_ONLY) rocm_check_target_ids(DEFAULT_GPU_TARGETS - TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") + TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") else() add_definitions(-DPROFILER_ONLY) + set(GPU_TARGETS "" CACHE STRING "" FORCE) if(GPU_TARGETS) - message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx9, gfx10, or gfx11") + message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11") endif() - if(GPU_ARCH MATCHES "gfx9") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942") + if(GPU_ARCH MATCHES "gfx90") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a") + elseif(GPU_ARCH MATCHES "gfx94") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx940;gfx941;gfx942") elseif(GPU_ARCH MATCHES "gfx10") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") elseif(GPU_ARCH MATCHES "gfx11") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") else() - message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx9, gfx10, or gfx11") + message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11") endif() + set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) endif() message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}") -set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " ") +set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) if(GPU_TARGETS) message("Building CK for the following targets: ${GPU_TARGETS}") @@ -143,31 +178,36 @@ if (NOT CK_BUILD_JIT_LIB) # SWDEV-413293 and https://reviews.llvm.org/D155213 math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}") message("hip_version_flat=${hip_VERSION_FLAT}") - if(${hip_VERSION_FLAT} GREATER 500723302) + if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500723302) message("Adding the fno-offload-uniform-block compiler flag") add_compile_options(-fno-offload-uniform-block) endif() - option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF) - option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF) - - if(USE_BITINT_EXTENSION_INT4) - add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) - add_compile_options(-Wno-bit-int-extension) - message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") + # + # Seperate linking jobs from compiling + # Too many concurrent linking jobs can break the build + # Copied from LLVM + set(CK_PARALLEL_LINK_JOBS "" CACHE STRING + "Define the maximum number of concurrent link jobs (Ninja only).") + if(CMAKE_GENERATOR MATCHES "Ninja") + if(CK_PARALLEL_LINK_JOBS) + set_property(GLOBAL APPEND PROPERTY JOB_POOLS link_job_pool=${CK_PARALLEL_LINK_JOBS}) + set(CMAKE_JOB_POOL_LINK link_job_pool) + endif() + elseif(CK_PARALLEL_LINK_JOBS) + message(WARNING "Job pooling is only available with Ninja generators.") endif() - - if(USE_OPT_NAVI3X) - add_compile_options(-mcumode) - add_compile_options(-mno-wavefrontsize64) - message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}") + # Similar for compiling + set(CK_PARALLEL_COMPILE_JOBS "" CACHE STRING + "Define the maximum number of concurrent compile jobs (Ninja only).") + if(CMAKE_GENERATOR MATCHES "Ninja") + if(CK_PARALLEL_COMPILE_JOBS) + set_property(GLOBAL APPEND PROPERTY JOB_POOLS compile_job_pool=${CK_PARALLEL_COMPILE_JOBS}) + set(CMAKE_JOB_POOL_COMPILE compile_job_pool) + endif() + elseif(CK_PARALLEL_COMPILE_JOBS) + message(WARNING "Job pooling is only available with Ninja generators.") endif() - - ## Threads - set(THREADS_PREFER_PTHREAD_FLAG ON) - find_package(Threads REQUIRED) - link_libraries(Threads::Threads) - ## OpenMP if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") # workaround issue hipcc in rocm3.5 cannot find openmp @@ -211,7 +251,11 @@ if (NOT CK_BUILD_JIT_LIB) endif() message(STATUS "Build with HIP ${HIP_VERSION}") link_libraries(hip::device) - add_compile_definitions(__HIP_PLATFORM_HCC__=1) + if(CK_hip_VERSION VERSION_GREATER_EQUAL 6.0.23494) + add_compile_definitions(__HIP_PLATFORM_AMD__=1) + else() + add_compile_definitions(__HIP_PLATFORM_HCC__=1) + endif() endif() ## tidy @@ -387,32 +431,28 @@ if (NOT CK_BUILD_JIT_LIB) set(cmake_instance) file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance) set(add_inst 0) - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8") - #message("fp8 instance found!") + if(("${cmake_instance}" MATCHES "fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16") - #message("fp16 instance found!") + if(("${cmake_instance}" MATCHES "bf8" OR "${cmake_instance}" MATCHES "_b8") AND DTYPES MATCHES "bf8") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32") - #message("fp32 instance found!") + if(("${cmake_instance}" MATCHES "fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64") - #message("fp64 instance found!") + if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16") - #message("bf16 instance found!") + if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8") - #message("int8 instance found!") + if(("${cmake_instance}" MATCHES "bf16" OR "${cmake_instance}" MATCHES "_b16") AND DTYPES MATCHES "bf16") + set(add_inst 1) + endif() + if(("${cmake_instance}" MATCHES "int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8") set(add_inst 1) endif() if(NOT "${cmake_instance}" MATCHES "DTYPES") - #message("instance should be built for all types!") set(add_inst 1) endif() if(add_inst EQUAL 1 OR NOT DEFINED DTYPES) @@ -436,18 +476,20 @@ if (NOT CK_BUILD_JIT_LIB) PACKAGE_NAME examples ) add_subdirectory(example) - add_subdirectory(test) + if(BUILD_TESTING) + add_subdirectory(test) + endif() rocm_package_setup_component(profiler LIBRARY_NAME composablekernel - PACKAGE_NAME ckProfiler + PACKAGE_NAME ckProfiler ) add_subdirectory(profiler) else() #When building PROFILER_ONLY, label the package with GPU_ARCH rocm_package_setup_component(profiler LIBRARY_NAME composablekernel - PACKAGE_NAME ckProfiler_${GPU_ARCH} + PACKAGE_NAME ckProfiler_${GPU_ARCH} ) add_subdirectory(profiler) endif() @@ -483,7 +525,7 @@ rocm_install(FILES ) # Install CK version and configuration files -install(FILES +rocm_install(FILES ${PROJECT_BINARY_DIR}/include/ck/version.h ${PROJECT_BINARY_DIR}/include/ck/config.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck/ diff --git a/Config.cmake.in b/Config.cmake.in index 0a19915565f1b2d719f8c0687689daf31f5fc8c4..b5fe3b847315d8882d71e541156ab40b7cb1e26a 100644 --- a/Config.cmake.in +++ b/Config.cmake.in @@ -1,6 +1,6 @@ @PACKAGE_INIT@ -set(_composable_kernel_supported_components device_operations utility jit_library) +set(_composable_kernel_supported_components device_other_operations device_gemm_operations device_conv_operations device_mha_operations device_contraction_operations device_reduction_operations utility jit_library) foreach(_comp ${composable_kernel_FIND_COMPONENTS}) if(NOT _comp IN_LIST _composable_kernel_supported_components) diff --git a/Dockerfile b/Dockerfile index e479268f48050957fabdaca7cfc5d162610f4fc1..38f234943c9324add594a554fa58ea7da554b9bf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ FROM ubuntu:20.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=5.6 +ARG ROCMVERSION=6.0 ARG compiler_version="" ARG compiler_commit="" @@ -16,72 +16,79 @@ RUN apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg -RUN wget https://repo.radeon.com/amdgpu-install/5.6/ubuntu/focal/amdgpu-install_5.6.50600-1_all.deb --no-check-certificate -RUN apt-get update && \ -DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ - ./amdgpu-install_5.6.50600-1_all.deb - -RUN if [ "$ROCMVERSION" != "5.7" ]; then \ +RUN if [ "$ROCMVERSION" != "6.0.1" ]; then \ + sh -c "wget https://repo.radeon.com/amdgpu-install/6.0/ubuntu/focal/amdgpu-install_6.0.60000-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.0.60000-1_all.deb && \ wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ - elif [ "$ROCMVERSION" = "5.7" ] && [ "$compiler_version" = "" ] || [ "$compiler_version" = "amd-stg-open" ]; then \ - sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_5.7-20.04-1_all.deb" && \ - apt update && apt-get install -y ./amdgpu-install-internal_5.7-20.04-1_all.deb && \ - amdgpu-repo --amdgpu-build=1609671 --rocm-build=compute-rocm-npi-mi300/1354; \ - elif [ "$ROCMVERSION" = "5.7" ] && [ "$compiler_version" = "rc1" ]; then \ - sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_5.7-20.04-1_all.deb" && \ - apt update && apt-get install -y ./amdgpu-install-internal_5.7-20.04-1_all.deb && \ - sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 5.7 rel-19 > /etc/apt/sources.list.d/rocm-build.list' && \ - amdgpu-repo --amdgpu-build=1637781; \ + elif [ "$ROCMVERSION" = "6.0.1" ] && [ "$compiler_version" = "rc1" ]; then \ + sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.0-20.04-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.0-20.04-1_all.deb && \ + sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.0.1 rel-95 > /etc/apt/sources.list.d/rocm-build.list' && \ + amdgpu-repo --amdgpu-build=1704947; \ fi RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" RUN amdgpu-install -y --usecase=rocm --no-dkms +## Sccache binary built from source for ROCm +ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache +ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin +RUN mkdir -p ${SCCACHE_INSTALL_LOCATION} && \ +curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \ +chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache +ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION} + # Install dependencies RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ build-essential \ - ccache \ cmake \ git \ hip-rocclr \ + iputils-ping \ jq \ libelf-dev \ libncurses5-dev \ libnuma-dev \ libpthread-stubs0-dev \ llvm-amdgpu \ + net-tools \ pkg-config \ python \ python3 \ python3-dev \ python3-pip \ + redis \ sshpass \ + stunnel \ software-properties-common \ vim \ nano \ zlib1g-dev \ + zip \ openssh-server \ clang-format-12 \ kmod && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* -#Install latest version of cmake +#Install latest ccache +RUN git clone https://github.com/ccache/ccache.git && \ + cd ccache && mkdir build && cd build && cmake .. && make install + +#Install ninja build tracing tools RUN wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip RUN gunzip /usr/local/bin/ninja.gz RUN chmod a+x /usr/local/bin/ninja RUN git clone https://github.com/nico/ninjatracing.git -RUN apt purge --auto-remove -y cmake -RUN apt update -RUN apt install -y software-properties-common lsb-release -RUN apt clean all -RUN wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | tee /etc/apt/trusted.gpg.d/kitware.gpg >/dev/null -RUN apt-add-repository "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" -RUN apt install -y kitware-archive-keyring -RUN rm /etc/apt/trusted.gpg.d/kitware.gpg -RUN apt install -y cmake +# Update the cmake to the latest version +RUN pip install --upgrade cmake==3.27.5 + +#Install latest cppcheck +RUN git clone https://github.com/danmar/cppcheck.git && \ + cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . +WORKDIR / # Setup ubsan environment to printstacktrace RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer @@ -96,9 +103,9 @@ ARG PREFIX=/opt/rocm RUN pip3 install --upgrade pip RUN pip3 install sqlalchemy==1.4.46 RUN pip3 install pymysql -RUN pip3 install pandas +RUN pip3 install pandas==2.0.3 RUN pip3 install setuptools-rust -RUN pip3 install sshtunnel +RUN pip3 install sshtunnel==0.4.0 # Setup ubsan environment to printstacktrace ENV UBSAN_OPTIONS=print_stacktrace=1 @@ -107,7 +114,7 @@ ENV LANG=C.UTF-8 RUN groupadd -f render # Install the new rocm-cmake version -RUN git clone -b master https://github.com/RadeonOpenCompute/rocm-cmake.git && \ +RUN git clone -b master https://github.com/ROCm/rocm-cmake.git && \ cd rocm-cmake && mkdir build && cd build && \ cmake .. && cmake --build . && cmake --build . --target install @@ -118,22 +125,24 @@ ENV compiler_commit=$compiler_commit RUN sh -c "echo compiler version = '$compiler_version'" RUN sh -c "echo compiler commit = '$compiler_commit'" -RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; then \ - git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ - cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ make -j 8 ; \ else echo "using the release compiler"; \ fi -RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" != "" ]; then \ - git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" != "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ - cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ make -j 8 ; \ else echo "using the release compiler"; \ fi +#clean-up the deb package +RUN sh -c "rm -rf amdgpu-install*" #ENV HIP_CLANG_PATH='/llvm-project/build/bin' #RUN sh -c "echo HIP_CLANG_PATH = '$HIP_CLANG_PATH'" diff --git a/Jenkinsfile b/Jenkinsfile index 668d9a613bf3c0120076875b3882941530b66ae2..becdc35b1687a4b57a1aa590a77e8438adb22198 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1,5 +1,5 @@ def rocmnode(name) { - return 'rocmtest && miopen && ' + name + return '(rocmtest || miopen) && ' + name } def show_node_info() { @@ -33,7 +33,7 @@ def runShell(String command){ def getDockerImageName(){ def img - if (params.ROCMVERSION != "5.7"){ + if (params.ROCMVERSION != "6.0.1"){ if (params.COMPILER_VERSION == "") { img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" } @@ -65,10 +65,10 @@ def getDockerImageName(){ } def check_host() { - if ("${env.CK_CCACHE}" != "null"){ - def CCACHE_SERVER="${env.CK_CCACHE.split(':')[0]}" - echo "ccache server: ${CCACHE_SERVER}" - sh '''ping -c 1 -p 6379 "${CCACHE_SERVER}" | echo $? > tmp.txt''' + if ("${env.CK_SCCACHE}" != "null"){ + def SCCACHE_SERVER="${env.CK_SCCACHE.split(':')[0]}" + echo "sccache server: ${SCCACHE_SERVER}" + sh '''ping -c 1 -p 6379 "${SCCACHE_SERVER}" | echo $? > tmp.txt''' def output = readFile(file: "tmp.txt") echo "tmp.txt contents: \$output" return (output != "0") @@ -84,7 +84,7 @@ def build_compiler(){ compiler = '/opt/rocm/bin/hipcc' } else{ - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ compiler = "/llvm-project/build/bin/clang++" } else{ @@ -96,24 +96,9 @@ def build_compiler(){ def getDockerImage(Map conf=[:]){ env.DOCKER_BUILDKIT=1 - def prefixpath = conf.get("prefixpath", "/opt/rocm") // prefix:/opt/rocm + def prefixpath = conf.get("prefixpath", "/opt/rocm") def no_cache = conf.get("no_cache", false) def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - echo "ccache server: ${env.CK_CCACHE}" - if(env.CK_CCACHE) - { - if(check_host()) - { - echo "FOUND CCACHE SERVER: ${env.CK_CCACHE}" - } - else - { - echo "CCACHE SERVER: ${env.CK_CCACHE} NOT FOUND, got ${check_host} response" - } - dockerArgs = dockerArgs + " --build-arg CCACHE_SECONDARY_STORAGE='redis://${env.CK_CCACHE}' --build-arg COMPILER_LAUNCHER='ccache' " - env.CCACHE_DIR = """/tmp/ccache_store""" - env.CCACHE_SECONDARY_STORAGE="""redis://${env.CK_CCACHE}""" - } if(no_cache) { dockerArgs = dockerArgs + " --no-cache " @@ -142,21 +127,6 @@ def buildDocker(install_prefix){ def image_name = getDockerImageName() echo "Building Docker for ${image_name}" def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - echo "ccache server: ${env.CK_CCACHE}" - if(env.CK_CCACHE) - { - if(check_host()) - { - echo "FOUND CCACHE SERVER: ${env.CK_CCACHE}" - } - else - { - echo "CCACHE SERVER: ${env.CK_CCACHE} NOT FOUND, got ${check_host} response" - } - dockerArgs = dockerArgs + " --build-arg CCACHE_SECONDARY_STORAGE='redis://${env.CK_CCACHE}' --build-arg COMPILER_LAUNCHER='ccache' " - env.CCACHE_DIR = """/tmp/ccache_store""" - env.CCACHE_SECONDARY_STORAGE="""redis://${env.CK_CCACHE}""" - } echo "Build Args: ${dockerArgs}" try{ @@ -164,18 +134,23 @@ def buildDocker(install_prefix){ //force building the new docker if that parameter is true echo "Building image: ${image_name}" retimage = docker.build("${image_name}", dockerArgs + ' .') - retimage.push() + withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + retimage.push() + } + sh 'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi' } else{ echo "Checking for image: ${image_name}" sh "docker manifest inspect --insecure ${image_name}" - echo "Image: ${image_name} found!! Skipping building image" + echo "Image: ${image_name} found! Skipping building image" } } catch(Exception ex){ echo "Unable to locate image: ${image_name}. Building image now" retimage = docker.build("${image_name}", dockerArgs + ' .') - retimage.push() + withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + retimage.push() + } } } @@ -210,19 +185,18 @@ def cmake_build(Map conf=[:]){ } else{ setup_args = ' -DBUILD_DEV=On' + setup_args } + if (params.DL_KERNELS){ + setup_args = setup_args + " -DDL_KERNELS=ON " + } if(build_type_debug){ setup_args = " -DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'" + setup_args }else{ setup_args = " -DCMAKE_BUILD_TYPE=release" + setup_args } - if(env.CK_CCACHE) - { - setup_args = " -DCMAKE_CXX_COMPILER_LAUNCHER='ccache' -DCMAKE_C_COMPILER_LAUNCHER='ccache' " + setup_args - } - echo "ccache server: ${env.CK_CCACHE}" def pre_setup_cmd = """ + #!/bin/bash echo \$HSA_ENABLE_SDMA ulimit -c unlimited rm -rf build @@ -231,6 +205,60 @@ def cmake_build(Map conf=[:]){ mkdir install cd build """ + def invocation_tag="" + if (setup_args.contains("gfx11")){ + invocation_tag="gfx11" + } + if (setup_args.contains("gfx10")){ + invocation_tag="gfx10" + } + if (setup_args.contains("gfx90")){ + invocation_tag="gfx90" + } + if (setup_args.contains("gfx94")){ + invocation_tag="gfx94" + } + echo "invocation tag: ${invocation_tag}" + def redis_pre_setup_cmd = pre_setup_cmd + if(check_host() && params.USE_SCCACHE && "${env.CK_SCCACHE}" != "null" && "${invocation_tag}" != "") { + redis_pre_setup_cmd = pre_setup_cmd + """ + #!/bin/bash + export ROCM_PATH=/opt/rocm + export SCCACHE_ENABLED=true + export SCCACHE_LOG_LEVEL=debug + export SCCACHE_IDLE_TIMEOUT=14400 + export COMPILERS_HASH_DIR=/tmp/.sccache + export SCCACHE_BIN=/usr/local/.cargo/bin/sccache + export SCCACHE_EXTRAFILES=/tmp/.sccache/rocm_compilers_hash_file + export SCCACHE_REDIS="redis://${env.CK_SCCACHE}" + echo "connect = ${env.CK_SCCACHE}" >> ../script/redis-cli.conf + export SCCACHE_C_CUSTOM_CACHE_BUSTER="${invocation_tag}" + echo \$SCCACHE_C_CUSTOM_CACHE_BUSTER + stunnel ../script/redis-cli.conf + ../script/sccache_wrapper.sh --enforce_redis + """ + try { + def cmd1 = conf.get("cmd1", """ + ${redis_pre_setup_cmd} + """) + sh cmd1 + setup_args = " -DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache " + setup_args + } + catch(Exception err){ + echo "could not connect to redis server: ${err.getMessage()}. will not use sccache." + def cmd2 = conf.get("cmd2", """ + ${pre_setup_cmd} + """) + sh cmd2 + } + } + else{ + def cmd3 = conf.get("cmd3", """ + ${pre_setup_cmd} + """) + sh cmd3 + } + def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") // reduce parallelism when compiling, clang uses too much memory def nt = nthreads() @@ -238,17 +266,19 @@ def cmake_build(Map conf=[:]){ def execute_cmd = conf.get("execute_cmd", "") def cmd = conf.get("cmd", """ - ${pre_setup_cmd} ${setup_cmd} ${build_cmd} ${execute_cmd} """) echo cmd - sh cmd + + dir("build"){ + sh cmd + } // Only archive from master or develop - if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "master")) { + if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "amd-master")) { archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true } } @@ -268,7 +298,7 @@ def buildHipClangJob(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -277,9 +307,9 @@ def buildHipClangJob(Map conf=[:]){ def retimage (retimage, image) = getDockerImage(conf) - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { - timeout(time: 5, unit: 'HOURS') + timeout(time: 48, unit: 'HOURS') { cmake_build(conf) } @@ -323,14 +353,14 @@ def runCKProfiler(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { @@ -367,8 +397,6 @@ def runCKProfiler(Map conf=[:]){ withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 24, unit: 'HOURS') { - //cmake_build(conf) - //instead of building, just unstash the ckProfiler and install it sh """ rm -rf build mkdir build @@ -456,7 +484,7 @@ def Build_CK(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-stg-open" || params.COMPILER_COMMIT != ""){ + if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -464,7 +492,7 @@ def Build_CK(Map conf=[:]){ def retimage def navi_node = 0 - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { @@ -525,6 +553,26 @@ def Build_CK(Map conf=[:]){ stash "ckprofiler_0.2.0_amd64.deb" } } + if (params.hipTensor_test && navi_node == 0 ){ + //build and test hipTensor + sh """#!/bin/bash + rm -rf "${params.hipTensor_branch}".zip + rm -rf hipTensor-"${params.hipTensor_branch}" + wget https://github.com/ROCm/hipTensor/archive/refs/heads/"${params.hipTensor_branch}".zip + unzip -o "${params.hipTensor_branch}".zip + """ + dir("hipTensor-${params.hipTensor_branch}"){ + sh """#!/bin/bash + mkdir -p build + ls -ltr + CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install" + cmake --build build -- -j + """ + } + dir("hipTensor-${params.hipTensor_branch}/build"){ + sh 'ctest' + } + } } } } @@ -562,7 +610,7 @@ def process_results(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) } @@ -612,9 +660,10 @@ def process_results(Map conf=[:]){ } //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION=rc1 - 0 21 * * * % ROCMVERSION=5.6;COMPILER_VERSION=;COMPILER_COMMIT= - 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : "" +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.0;COMPILER_VERSION= + 0 21 * * * % ROCMVERSION=6.0;COMPILER_VERSION=;COMPILER_COMMIT= + 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;COMPILER_COMMIT=;USE_SCCACHE=false + 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : "" pipeline { agent none @@ -631,24 +680,48 @@ pipeline { description: "Force building docker image (default: false), set to true if docker image needs to be updated.") string( name: 'ROCMVERSION', - defaultValue: '5.6', - description: 'Specify which ROCM version to use: 5.6 (default).') + defaultValue: '6.0', + description: 'Specify which ROCM version to use: 6.0 (default).') string( name: 'COMPILER_VERSION', defaultValue: '', - description: 'Specify which version of compiler to use: release, amd-stg-open, or leave blank (default).') + description: 'Specify which version of compiler to use: release, amd-staging, amd-mainline-open, or leave blank (default).') string( name: 'COMPILER_COMMIT', defaultValue: '', - description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit, or use 5541927df00eabd6a110180170eca7785d436ee3 (default) commit of amd-stg-open branch.') + description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit (default), or use some specific commit of llvm-project branch.') string( name: 'BUILD_COMPILER', - defaultValue: 'hipcc', - description: 'Specify whether to build CK with hipcc (default) or with clang.') + defaultValue: 'clang', + description: 'Specify whether to build CK with hipcc or with clang (default).') booleanParam( name: "RUN_FULL_QA", defaultValue: false, description: "Select whether to run small set of performance tests (default) or full QA") + booleanParam( + name: "DL_KERNELS", + defaultValue: false, + description: "Select whether to build DL kernels (default: OFF)") + booleanParam( + name: "hipTensor_test", + defaultValue: true, + description: "Use the CK build to verify hipTensor build and tests (default: ON)") + string( + name: 'hipTensor_branch', + defaultValue: 'mainline', + description: 'Specify which branch of hipTensor to use (default: mainline)') + booleanParam( + name: "USE_SCCACHE", + defaultValue: true, + description: "Use the sccache for building CK (default: ON)") + booleanParam( + name: "RUN_CPPCHECK", + defaultValue: false, + description: "Run the cppcheck static analysis (default: OFF)") + booleanParam( + name: "RUN_PERFORMANCE_TESTS", + defaultValue: false, + description: "Run the performance tests (default: OFF)") } environment{ dbuser = "${dbuser}" @@ -663,22 +736,51 @@ pipeline { } stages{ stage("Build Docker"){ - //when { - // beforeAgent true - // expression { params.BUILD_DOCKER.toBoolean() } - //} parallel{ stage('Docker /opt/rocm'){ agent{ label rocmnode("nogpu") } steps{ buildDocker('/opt/rocm') + cleanWs() } } } } stage("Static checks") { parallel{ + stage('Clang Format and Cppcheck') { + when { + beforeAgent true + expression { params.RUN_CPPCHECK.toBoolean() } + } + agent{ label rocmnode("nogpu") } + environment{ + execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \ + -o -not -path \'*.git*\' -iname \'*.hpp\' \ + -o -not -path \'*.git*\' -iname \'*.cpp\' \ + -o -iname \'*.h.in\' \ + -o -iname \'*.hpp.in\' \ + -o -iname \'*.cpp.in\' \ + -o -iname \'*.cl\' \ + | grep -v 'build/' \ + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \ + /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ + -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 -D DL_KERNELS \ + -D __gfx908__ -D __gfx90a__ -D __gfx940__ -D __gfx941__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ + -U __gfx803__ -U __gfx900__ -U __gfx906__ -U CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 \ + --file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log" + } + steps{ + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) + archiveArtifacts "build/ck_cppcheck.log" + cleanWs() + } + } stage('Clang Format') { + when { + beforeAgent true + expression { !params.RUN_CPPCHECK.toBoolean() } + } agent{ label rocmnode("nogpu") } environment{ execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \ @@ -693,6 +795,7 @@ pipeline { } steps{ buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) + cleanWs() } } } @@ -710,11 +813,19 @@ pipeline { } agent{ label rocmnode("gfx908 || gfx90a") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941" """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ + -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ + -DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert " \ + -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() } } stage("Build CK and run Tests on MI100/MI200") @@ -725,11 +836,16 @@ pipeline { } agent{ label rocmnode("gfx908 || gfx90a") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx908;gfx90a" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() } } stage("Build CK and run Tests on Navi21") @@ -740,12 +856,16 @@ pipeline { } agent{ label rocmnode("navi21") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ - + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx1030" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() } } stage("Build CK and run Tests on Navi32") @@ -756,12 +876,16 @@ pipeline { } agent{ label rocmnode("navi32") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DDTYPES="fp16;fp32;bf16" -DGPU_TARGETS="gfx1101" """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDTYPES="fp16;fp32;bf16" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ - + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx1101" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() } } } @@ -775,7 +899,7 @@ pipeline { { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() } + expression { !params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() } } options { retry(2) } agent{ label rocmnode("gfx908 || gfx90a")} @@ -784,13 +908,14 @@ pipeline { } steps{ runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') + cleanWs() } } stage("Run ckProfiler: gfx90a") { when { beforeAgent true - expression { params.RUN_FULL_QA.toBoolean() } + expression { params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() } } options { retry(2) } agent{ label rocmnode("gfx90a")} @@ -799,6 +924,7 @@ pipeline { } steps{ runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') + cleanWs() } } } @@ -808,9 +934,14 @@ pipeline { parallel { stage("Process results"){ + when { + beforeAgent true + expression { params.RUN_PERFORMANCE_TESTS.toBoolean() } + } agent { label 'mici' } steps{ process_results() + cleanWs() } } } diff --git a/LICENSE b/LICENSE index e03fddaf78080705d26ec277629cfb8010c077bf..581b5efde535f686aa0a584709fccaa92353d125 100644 --- a/LICENSE +++ b/LICENSE @@ -7,7 +7,7 @@ Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) SPDX-License-Identifier: MIT -Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index c2b493db11596d489c6bc5e4fcf9549080998c01..488991469102f90ef5ec3d1ee89c6d9277ebc5e2 100644 --- a/README.md +++ b/README.md @@ -1,139 +1,188 @@ # Composable Kernel -## Methodology +The Composable Kernel (CK) library provides a programming model for writing performance-critical +kernels for machine learning workloads across multiple architectures (GPUs, CPUs, etc.). The CK library +uses general purpose kernel languages, such as HIP C++. -Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. +CK uses two concepts to achieve performance portability and code maintainability: -CK utilizes two concepts to achieve performance portability and code maintainability: * A tile-based programming model -* Algorithm complexity reduction for complex ML operators, using innovative technique we call "Tensor Coordinate Transformation". +* Algorithm complexity reduction for complex machine learning (ML) operators. This uses an innovative + technique called *Tensor Coordinate Transformation*. ![ALT](/docs/data/ck_component.png "CK Components") -## Code Structure +The current CK library is structured into four layers: -Current CK library are structured into 4 layers: -* "Templated Tile Operators" layer -* "Templated Kernel and Invoker" layer -* "Instantiated Kernel and Invoker" layer -* "Client API" layer +* Templated Tile Operators +* Templated Kernel and Invoker +* Instantiated Kernel and Invoker +* Client API ![ALT](/docs/data/ck_layer.png "CK Layers") -## Documentation +## General information -Run the steps below to build documentation locally. +To build our documentation locally, use the following code: -``` +``` bash cd docs pip3 install -r sphinx/requirements.txt python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html ``` -## Contributors - -The list of developers and contributors is here: [Contributors](/CONTRIBUTORS.md) +You can find a list of our developers and contributors on our [Contributors](/CONTRIBUTORS.md) page. -## Citation +```note +If you use CK, cite us as follows: -If you use CK, please use following citations: -* CK paper will be freely available on arXiv soon: [Realizing Tensor Operators Using Coordinate Transformations and Tile Based Programming](???) +* [Realizing Tensor Operators Using Coordinate Transformations and Tile Based Programming](???): + This paper will be available on arXiv soon. * [CITATION.cff](/CITATION.cff) +``` -## License +CK is released under the **[MIT license](/LICENSE)**. -CK is released under the MIT license. [License File](/LICENSE) +## Building CK +We recommend building CK inside Docker containers, which include all necessary packages. Pre-built +Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composable_kernel/tags). -# Build CK +1. To build a new Docker image, use the Dockerfile provided with the source code: -## Build docker image + ```bash + DOCKER_BUILDKIT=1 docker build -t ck:latest -f Dockerfile . + ``` -```bash -DOCKER_BUILDKIT=1 docker build -t ck:latest -f Dockerfile . -``` -Pre-built dockers are available from this public repo: -https://hub.docker.com/r/rocm/composable_kernel/tags +2. Launch the Docker container: -## Launch docker + ```bash + docker run \ + -it \ + --privileged \ + --group-add sudo \ + -w /root/workspace \ + -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ + ck:latest \ + /bin/bash + ``` -```bash -docker run \ --it \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ -ck:latest \ -/bin/bash -``` +3. Clone CK source code from the GitHub repository and start the build: -## Build CK + ```bash + git clone https://github.com/ROCm/composable_kernel.git && \ + cd composable_kernel && \ + mkdir build && \ + cd build + ``` -```bash -mkdir build && cd build + You must set the `GPU_TARGETS` macro to specify the GPU target architecture(s) you want + to run CK on. You can specify single or multiple architectures. If you specify multiple architectures, + use a semicolon between each; for example, `gfx908;gfx90a;gfx940`. -# Need to specify target ID, example below is for gfx908 and gfx90a + ```bash + cmake \ + -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx908;gfx90a" \ + .. + ``` -cmake \ --D CMAKE_PREFIX_PATH=/opt/rocm \ --D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_BUILD_TYPE=Release \ --D GPU_TARGETS="gfx908;gfx90a" \ -.. -``` + If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets + supported by the current compiler (this may take a long time). -If GPU_TARGETS is not set on the cmake command line, CK will be built for all targets supported by the -current compiler. +4. Build the entire CK library: + ```bash + make -j + ``` -Additional cmake flags can be used to significantly speed-up the build: +5. Install CK: -INSTANCES_ONLY (by default is OFF) must be set to ON in order to build only the instances and library -while skipping all tests, examples, and profiler. This is useful for libraries that use CK as a dependency. + ```bash + make -j install + ``` -DTYPES (by default not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build instances -of select data types only. Currently, building of int8 instances is taking a lot of time (the compiler fix is in the works). +## Optional post-install steps -DL_KERNELS (by default is OFF) must be set to ON in order to build the gemm_dl and batched_gemm_multi_d_dl -instances. Those instances are only needed for the NAVI2x platforms. +* Build examples and tests: -### Build examples and tests + ```bash + make -j examples tests + ``` -```bash - make -j examples tests - make test -``` +* Build and run all examples and tests: + + ```bash + make -j check + ``` -Instructions for running each individual examples are under [example](/example) + You can find instructions for running each individual example in [example](/example). +* Build ckProfiler: -## Build ckProfiler + ```bash + make -j ckProfiler + ``` + + You can find instructions for running ckProfiler in [profiler](/profiler). + +Note the `-j` option for building with multiple threads in parallel. This speeds up the build significantly. +Depending on the number of CPU cores and the amount of RAM on your system, you may want to +limit the number of threads. For example, if you have a 128-core CPU and 64 Gb of RAM. + +By default, `-j` launches one thread per CPU core, which can cause the build to run out of memory and +crash. In such cases, you can reduce the number of threads to 32 by using `-j32`. + +Additional cmake flags can be used to significantly speed-up the build: + +* `INSTANCES_ONLY` (default is OFF) must be set to ON in order to build only the instances and library + while skipping all tests, examples, and profiler. This is useful in cases when you plan to use CK as a + dependency and don't plan to run any examples or tests. + +* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build + instances of select data types only. The main default data types are fp32 and fp16; you can safely skip + other data types. + +* `DL_KERNELS` (default is OFF) must be set to ON in order to build instances, such as `gemm_dl` or + `batched_gemm_multi_d_dl`. These instances are useful on architectures like the NAVI2x, as most + other platforms have faster instances, such as `xdl` or `wmma`, available. + +## Using sccache for building + +The default CK Docker images come with a pre-installed version of sccache, which supports clang +being used as hip-compiler (" -x hip"). Using sccache can help reduce the time to re-build code from +hours to 1-2 minutes. In order to invoke sccache, you need to run: ```bash - make -j ckProfiler + sccache --start-server ``` -Instructions for running ckProfiler are under [profiler](/profiler) -## Install CK +then add the following flags to the cmake command line: ```bash -make install + -DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache ``` +You may need to clean up the build folder and repeat the cmake and make steps in order to take +advantage of the sccache during subsequent builds. + ## Using CK as pre-built kernel library -Instructions for using CK as a pre-built kernel library are under [client_example](/client_example) +You can find instructions for using CK as a pre-built kernel library in [client_example](/client_example). -## Contributing +## Contributing to CK -When you contribute to Composable Kernel, make sure to run `clang-format` on all the changed files. We highly recommend using git hooks that are managed by the `pre-commit` framework. To install hooks, run: +When you contribute to CK, make sure you run `clang-format` on all changed files. We highly +recommend using git hooks that are managed by the `pre-commit` framework. To install hooks, run: ```bash sudo script/install_precommit.sh ``` -This way, `pre-commit` will add the appropriate hooks to your local repository and automatically run `clang-format` (and possibly additional checks) before any commit is created. +With this approach, `pre-commit` adds the appropriate hooks to your local repository and +automatically runs `clang-format` (and possibly additional checks) before any commit is created. If you need to uninstall hooks from the repository, you can do so by running the following command: @@ -141,14 +190,5 @@ If you need to uninstall hooks from the repository, you can do so by running the script/uninstall_precommit.sh ``` -If for any reason, you need to temporarily disable precommit hooks, you can add the `--no-verify` option to the `git commit` command. - -## Caveat -### Kernel Timing and Verification - -CK's own kernel timer will warn up kernel once, and then run it multiple times -to get average kernel time. For some kernels that use atomic add, this will cause -output buffer to be accumulated multiple times, causing verification failure. -To work around it, do not use CK's own timer and do verification at the same time. -CK's own timer and verification in each example and ckProfiler can be enabled or -disabled from command line. +If you need to temporarily disable pre-commit hooks, you can add the `--no-verify` option to the +`git commit` command. diff --git a/client_example/01_gemm/CMakeLists.txt b/client_example/01_gemm/CMakeLists.txt index 9e741192f90b8216e4b3abe32ae8971fb45ddfee..6c4103cda879128992bfe2b0ff4757f8f8b95a9c 100644 --- a/client_example/01_gemm/CMakeLists.txt +++ b/client_example/01_gemm/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(client_gemm gemm.cpp) -target_link_libraries(client_gemm PRIVATE composable_kernel::device_operations) +target_link_libraries(client_gemm PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) diff --git a/client_example/01_gemm/gemm.cpp b/client_example/01_gemm/gemm.cpp index c37f208db1cee9bfff1fff469fd79d059fb179f0..11f92228734385432c73fe7c84e0d796fe5b7258 100644 --- a/client_example/01_gemm/gemm.cpp +++ b/client_example/01_gemm/gemm.cpp @@ -185,6 +185,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt index ba2952022233b3ffd21fc245b8fd199a5ec5b096..772b699955bb06ebcad9a259528354c67322c3cb 100644 --- a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt +++ b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,13 +1,13 @@ add_custom_target(client_gemm_fastgelu_examples) add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp) -target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_operations) +target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp) -target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_operations) +target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) add_executable(client_gemm_fastgelu gemm_fastgelu.cpp) -target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_operations) +target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu client_gemm_fastgelu) @@ -15,13 +15,13 @@ add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu clie add_custom_target(client_gemm_fastgelu_generic_examples) add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp) -target_link_libraries(client_gemm_add_add_fastgelu_generic PRIVATE composable_kernel::device_operations) +target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations) add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) -target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_operations) +target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) -target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_operations) +target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic) diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp index 756889562e84c66efb5f972621bcb61edda3af82..e845c120d8874d8486a1e7ccdd166ca7967b74b9 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp @@ -204,6 +204,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp index 8d2a8c234aae63a3566478b5aa9588389a247d4e..e77b67c9058993739cd1466966c0c850b9a19a2f 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp @@ -197,6 +197,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp index c02df018fd35c6d37a2c6f9b9fded6390c9afb19..7648da9cac16373cf62311857d096e9d96d037ea 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp @@ -190,6 +190,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/03_gemm_layernorm/CMakeLists.txt b/client_example/03_gemm_layernorm/CMakeLists.txt index b38698d9064e6fa3cbd2e01c8818c0ebfe361cb3..94b4576f64506b09f5fa9d1e9f864a2375e7c650 100644 --- a/client_example/03_gemm_layernorm/CMakeLists.txt +++ b/client_example/03_gemm_layernorm/CMakeLists.txt @@ -1,5 +1,5 @@ add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp) -target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_operations) +target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp) -target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_operations) +target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) diff --git a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp index 3d5fb6004844af269f7786ae05de8c20cc720633..93f8847c626a3451a873fc053629c874b0773bb2 100644 --- a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp +++ b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp @@ -200,6 +200,7 @@ int main(int argc, char* argv[]) << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/04_contraction/CMakeLists.txt b/client_example/04_contraction/CMakeLists.txt index 7ffedfeef36839c1c8bb942ffae11e5f6dcadf22..cd4a95124cf2a47daba494451bdbe0e3fd98c274 100644 --- a/client_example/04_contraction/CMakeLists.txt +++ b/client_example/04_contraction/CMakeLists.txt @@ -1,15 +1,15 @@ add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) -target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp) -target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp) -target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp) -target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp) -target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_operations) +target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) diff --git a/client_example/05_layernorm/CMakeLists.txt b/client_example/05_layernorm/CMakeLists.txt index b582b485d4ce46951aaed98c6256e1c997388bd3..b7b3c830ed5beaab6bb508d99799c998dd80e340 100644 --- a/client_example/05_layernorm/CMakeLists.txt +++ b/client_example/05_layernorm/CMakeLists.txt @@ -1,2 +1,11 @@ -add_executable(client_layernorm2d layernorm2d.cpp) -target_link_libraries(client_layernorm2d PRIVATE composable_kernel::device_operations) +add_executable(client_layernorm2d_bwd_data layernorm2d_bwd_data.cpp) +target_link_libraries(client_layernorm2d_bwd_data PRIVATE composable_kernel::device_other_operations) + +add_executable(client_layernorm2d_bwd_gamma_beta layernorm2d_bwd_gamma_beta.cpp) +target_link_libraries(client_layernorm2d_bwd_gamma_beta PRIVATE composable_kernel::device_other_operations) + +add_executable(client_layernorm2d_fwd layernorm2d_fwd.cpp) +target_link_libraries(client_layernorm2d_fwd PRIVATE composable_kernel::device_other_operations) + +add_executable(client_layernorm4d_fwd layernorm4d_fwd.cpp) +target_link_libraries(client_layernorm4d_fwd PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/05_layernorm/layernorm2d_bwd_data.cpp b/client_example/05_layernorm/layernorm2d_bwd_data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f26cb6840716c04ea90df60d15972bbafa0c9dd --- /dev/null +++ b/client_example/05_layernorm/layernorm2d_bwd_data.cpp @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_data.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DXDataType = float; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t M = 1024; + ck::index_t N = 1024; + + SimpleDeviceMem dy_dev(sizeof(DYDataType) * M * N); + SimpleDeviceMem x_dev(sizeof(XDataType) * M * N); + SimpleDeviceMem gamma_dev(sizeof(GammaDataType) * N); + SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * M); + SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * M); + SimpleDeviceMem dx_dev(sizeof(DXDataType) * M * N); + + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationBwdData; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {0, 1}, // gammaStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N, 1}, // dxStrides + {1}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = sizeof(DYDataType) * M * N + sizeof(XDataType) * M * N + + sizeof(GammaDataType) * N + sizeof(MeanInvStdDataType) * M * 2 + + sizeof(DXDataType) * M * N; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {0, 1}, // gammaStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N, 1}, // dxStrides + {1}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp b/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp new file mode 100644 index 0000000000000000000000000000000000000000..98b394add63b80e337f21d5dab10b2b16d03ae54 --- /dev/null +++ b/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" + +#include "ck/library/tensor_operation_instance/gpu/layernorm_bwd_gamma_beta.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DGammaDataType = float; +using DBetaDataType = float; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t M = 1024; + ck::index_t N = 1024; + + SimpleDeviceMem dy_dev(sizeof(DYDataType) * M * N); + SimpleDeviceMem x_dev(sizeof(XDataType) * M * N); + SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * M); + SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * M); + SimpleDeviceMem dgamma_dev(sizeof(DGammaDataType) * N); + SimpleDeviceMem dbeta_dev(sizeof(DBetaDataType) * N); + + using DeviceOp = + ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::size_t num_bytes = sizeof(DYDataType) * M * N + sizeof(XDataType) * M * N + + sizeof(MeanInvStdDataType) * M * 2 + sizeof(DGammaDataType) * N + + sizeof(DBetaDataType) * N; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // inLengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N}, // outLengths + {1}, // dgammaStrides + {1}, // dbetaStrides + {0}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // inLengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N}, // outLengths + {1}, // dgammaStrides + {1}, // dbetaStrides + {0}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/05_layernorm/layernorm2d.cpp b/client_example/05_layernorm/layernorm2d_fwd.cpp similarity index 74% rename from client_example/05_layernorm/layernorm2d.cpp rename to client_example/05_layernorm/layernorm2d_fwd.cpp index 4af4d7abe8ee6c5120f6060c78141021052cb619..420225b61352cfb84d0dbc9117533b9b0282bc32 100644 --- a/client_example/05_layernorm/layernorm2d.cpp +++ b/client_example/05_layernorm/layernorm2d_fwd.cpp @@ -7,17 +7,19 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/normalization.hpp" +#include "ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp" -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = ck::half_t; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD constexpr int Rank = 2; constexpr int NumReduceDim = 1; @@ -50,15 +52,19 @@ int main(int argc, char* argv[]) SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N); SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N); SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size); - - using DeviceOp = ck::tensor_operation::device::DeviceNormalization; +#ifdef SAVE_MEAN_INV_STD + SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * M); + SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * M); +#endif + + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationFwd; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< @@ -84,14 +90,21 @@ int main(int argc, char* argv[]) {0, 1}, // gammaStrides {0, 1}, // betaStrides {Stride, 1}, // yStrides + {1}, // save_mean Strides + {1}, // save_inv_std Strides {1}, // reduceDims 1e-4, x_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(), beta_device_buf.GetDeviceBuffer(), y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else nullptr, nullptr, +#endif PassThrough{}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); @@ -100,11 +113,19 @@ int main(int argc, char* argv[]) if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; +#ifdef SAVE_MEAN_INV_STD + num_byte += sizeof(SaveMeanInvStdDataType) * M * 2; +#endif + float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " @@ -129,6 +150,7 @@ int main(int argc, char* argv[]) << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() @@ -136,23 +158,34 @@ int main(int argc, char* argv[]) auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths {Stride, 1}, // xStrides - {1}, // gammaStrides - {1}, // betaStrides + {0, 1}, // gammaStrides + {0, 1}, // betaStrides {Stride, 1}, // yStrides + {1}, // save_mean Strides + {1}, // save_inv_std Strides {1}, // reduceDims 1e-4, x_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(), beta_device_buf.GetDeviceBuffer(), y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else nullptr, nullptr, +#endif PassThrough{}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); } diff --git a/client_example/05_layernorm/layernorm4d_fwd.cpp b/client_example/05_layernorm/layernorm4d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fa408dc751d165eca209692905749034baa7a99c --- /dev/null +++ b/client_example/05_layernorm/layernorm4d_fwd.cpp @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/normalization_fwd.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = ck::half_t; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD + +constexpr int Rank = 4; +constexpr int NumReduceDim = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 256; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t C = 8; + + std::vector strideXY = {H * W * C, W * C, C, 1}; + std::vector strideGammaBeta = {0, W * C, C, 1}; + std::vector strideSaveMeanInvStd = {1}; + + SimpleDeviceMem x_device_buf(sizeof(XDataType) * N * H * W * C); + SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * H * W * C); + SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * H * W * C); + SimpleDeviceMem y_device_buf(sizeof(YDataType) * N * H * W * C); +#ifdef SAVE_MEAN_INV_STD + SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * N); + SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * N); +#endif + + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationFwd; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = + op_ptr->MakeArgumentPointer({N, H, W, C}, // lengths + strideXY, // xStrides + strideGammaBeta, // gammaStrides + strideGammaBeta, // betaStrides + strideXY, // yStrides + strideSaveMeanInvStd, // save_mean Strides + strideSaveMeanInvStd, // save_inv_std Strides + {1, 2, 3}, // reduceDims + 1e-4, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = + sizeof(XDataType) * N * H * W * C + sizeof(GammaDataType) * H * W * C + + sizeof(BetaDataType) * H * W * C + sizeof(YDataType) * N * H * W * C; + +#ifdef SAVE_MEAN_INV_STD + num_byte += sizeof(SaveMeanInvStdDataType) * N * 2; +#endif + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = + op_ptr->MakeArgumentPointer({N, H, W, C}, // lengths + strideXY, // xStrides + strideGammaBeta, // gammaStrides + strideGammaBeta, // betaStrides + strideXY, // yStrides + strideSaveMeanInvStd, // save_mean Strides + strideSaveMeanInvStd, // save_inv_std Strides + {1, 2, 3}, // reduceDims + 1e-4, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/06_softmax/CMakeLists.txt b/client_example/06_softmax/CMakeLists.txt index b38a0fd9e27570e62df7f701572a24fb4ee842f9..24d30f475ec8c44b41375d9d972f04e6aeec4364 100644 --- a/client_example/06_softmax/CMakeLists.txt +++ b/client_example/06_softmax/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(client_softmax4d softmax4d.cpp) -target_link_libraries(client_softmax4d PRIVATE composable_kernel::device_operations) +target_link_libraries(client_softmax4d PRIVATE composable_kernel::device_other_operations composable_kernel::device_reduction_operations) diff --git a/client_example/06_softmax/softmax4d.cpp b/client_example/06_softmax/softmax4d.cpp index 2ccad27a88757b576a050d0558acc4108857cbd1..a62af766351a2577ecca48c3e6c1f656ce9c9d68 100644 --- a/client_example/06_softmax/softmax4d.cpp +++ b/client_example/06_softmax/softmax4d.cpp @@ -140,6 +140,7 @@ int main(int argc, char* argv[]) << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() diff --git a/client_example/07_grouped_convnd_fwd/CMakeLists.txt b/client_example/07_grouped_convnd_fwd/CMakeLists.txt index fce7e91c1ef5b23fcfecfe99cc49263a1d779d76..40f1bba064a27f50c6a04a183aa05b8ac6dc6cb8 100644 --- a/client_example/07_grouped_convnd_fwd/CMakeLists.txt +++ b/client_example/07_grouped_convnd_fwd/CMakeLists.txt @@ -1,5 +1,5 @@ add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp) -target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_operations) +target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations) add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) -target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_operations) +target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp index 70be0101c6d92512ab28a72797d2b8c46fb55281..4d743a66f0acc327fbeb9cd68b006037929e2dfb 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp @@ -100,18 +100,18 @@ int main() SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * X * C); SimpleDeviceMem out(sizeof(OutDataType) * G * N * Wo * K); - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - PassThrough>; + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp index 57a210fa1f5f0d3f0eca014e08ccd119d29e4fd6..c5e51ad993472b5166ee7a62d962d5a497b6a93f 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp @@ -71,18 +71,18 @@ int main() SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - PassThrough>; + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< diff --git a/client_example/08_fused_attention/CMakeLists.txt b/client_example/08_fused_attention/CMakeLists.txt index 862b9ed5b70ee8a35c4ba44e2ccf1441e1532bd3..9472be07b57fb545d274a8e188bfa6f90179c317 100644 --- a/client_example/08_fused_attention/CMakeLists.txt +++ b/client_example/08_fused_attention/CMakeLists.txt @@ -1,5 +1,5 @@ add_executable(client_fused_attention fused_attention.cpp) -target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_operations) +target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) add_executable(client_fused_attention_bias fused_attention_bias.cpp) -target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_operations) +target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) diff --git a/client_example/09_quantization/CMakeLists.txt b/client_example/09_quantization/CMakeLists.txt index ac11aad45de84b8a2e042d8a1a4bb059b889f3f2..65ad642ce28ebe0699cd80a95ee25383266c55c2 100644 --- a/client_example/09_quantization/CMakeLists.txt +++ b/client_example/09_quantization/CMakeLists.txt @@ -1,22 +1,22 @@ if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations) +target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_operations) +target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_operations) +target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_operations) +target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_operations) +target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_operations) +target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) add_executable(client_gemm_quantization gemm_quantization.cpp) -target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_operations) +target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) endif() diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp index cd504e942e943b8f174442554eca379cd908ba6e..78db4f8aa5bced84460c4db05d488e9f0be33b09 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp @@ -80,7 +80,7 @@ int main(int argc, char* argv[]) SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< NumDimSpatial, InLayout, WeiLayout, diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp index f4aa3666b1c15eccd37bdef549f8402bcd1b252b..4121e41af7d02981628bb80d1c7c6364515876aa 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp @@ -78,18 +78,18 @@ int main(int argc, char* argv[]) SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple, - OutDataType, - PassThrough, - PassThrough, - OutElementOp>; + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp index ebdbbf52c0ca257dd8f672a48b32f80fa5bb8816..ea5f1dbd5b7c3994cce418d9b8fc53121ee42f50 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp @@ -83,7 +83,7 @@ int main(int argc, char* argv[]) SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< NumDimSpatial, InLayout, WeiLayout, diff --git a/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp index 9d60baee06fddb27451f1f51642bad738739c4bb..5b40298d6c761dddae1228f7242aa199b4b9acf1 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp @@ -79,18 +79,18 @@ int main(int argc, char* argv[]) SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple, - OutDataType, - PassThrough, - PassThrough, - OutElementOp>; + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp index dd81d9ee6b6dc0ee453cac61d7e795c79ba3c9fa..0b78bbf272220a2ffedb775dd2e331aa44264b48 100644 --- a/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp @@ -76,19 +76,19 @@ int main(int argc, char* argv[]) SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); - using DeviceOp = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple, - OutDataType, - PassThrough, - PassThrough, - OutElementOp>; + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< + NumDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp index 9c088a21d38e7f62c8eee35bf4c557a0bb2209bd..7315f2bb55d57868201ee15114f06746c0208bf3 100644 --- a/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp @@ -72,18 +72,18 @@ int main(int argc, char* argv[]) SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - OutElementOp>; + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt b/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt deleted file mode 100644 index e564f3180d8c9295356d90266f2159de47aac060..0000000000000000000000000000000000000000 --- a/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_executable(client_grouped_conv2d_bwd_data grouped_conv2d_bwd_data.cpp) -target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::device_operations) diff --git a/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt b/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0cf308c6e15cfdb6687817e10f4df16e18d9b1b7 --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt @@ -0,0 +1,8 @@ +add_executable(client_grouped_conv2d_bwd_data grouped_conv2d_bwd_data.cpp) +target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_conv3d_bwd_data grouped_conv3d_bwd_data.cpp) +target_link_libraries(client_grouped_conv3d_bwd_data PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp) +target_link_libraries(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp similarity index 100% rename from client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp rename to client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d2f2ff41bcc8f3eaa901594da9957d4cc8ff9dbb --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 2; +static constexpr ck::index_t N = 16; +static constexpr ck::index_t K = 16; +static constexpr ck::index_t C = 16; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 14; +static constexpr ck::index_t Hi = 14; +static constexpr ck::index_t Wi = 14; +static constexpr ck::index_t Do = 14; +static constexpr ck::index_t Ho = 14; +static constexpr ck::index_t Wo = 14; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + PassThrough, + PassThrough, + PassThrough>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X; + std::size_t num_bytes = sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2330228d1d8d9f3fbc41ede8fb58da6c916856cc --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 2; +static constexpr ck::index_t N = 16; +static constexpr ck::index_t K = 16; +static constexpr ck::index_t C = 16; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 14; +static constexpr ck::index_t Hi = 14; +static constexpr ck::index_t Wi = 14; +static constexpr ck::index_t Do = 14; +static constexpr ck::index_t Ho = 14; +static constexpr ck::index_t Wo = 14; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + PassThrough, + PassThrough, + PassThrough, + ck::bf8_t, + ck::f8_t>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X; + std::size_t num_bytes = sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt b/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt index 82162b606a6ee055cb985b960366a93fa4dd21e3..dddfabb787460d4899ddf8e2a9df8de4d83f5836 100644 --- a/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt +++ b/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt @@ -2,8 +2,10 @@ add_executable(client_grouped_conv1d_bwd_weight_fp16 grouped_conv1d_bwd_weight_f add_executable(client_grouped_conv2d_bwd_weight_fp16 grouped_conv2d_bwd_weight_fp16.cpp) add_executable(client_grouped_conv3d_bwd_weight_fp16 grouped_conv3d_bwd_weight_fp16.cpp) add_executable(client_grouped_conv3d_bwd_weight_fp32 grouped_conv3d_bwd_weight_fp32.cpp) +add_executable(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp) -target_link_libraries(client_grouped_conv1d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations) -target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations) -target_link_libraries(client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations) -target_link_libraries(client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_grouped_conv1d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations) +target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations) +target_link_libraries(client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations) +target_link_libraries(client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_conv_operations) +target_link_libraries(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/11_grouped_conv_bwd_weight/common.hpp b/client_example/11_grouped_conv_bwd_weight/common.hpp index 4292cded2071526381db36c31a5e660f7087cbce..1a36490ef4d279aae5fa25ea8b3a55819842df82 100644 --- a/client_example/11_grouped_conv_bwd_weight/common.hpp +++ b/client_example/11_grouped_conv_bwd_weight/common.hpp @@ -85,7 +85,9 @@ template + typename OutLayout, + typename AComputeType = InDataType, + typename BComputeType = AComputeType> bool run_grouped_conv_bwd_weight( const std::array& input_lengths, const std::array& input_strides, @@ -113,7 +115,9 @@ bool run_grouped_conv_bwd_weight( OutDataType, PassThrough, PassThrough, - PassThrough>; + PassThrough, + AComputeType, + BComputeType>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..098b7cd868797824b2ed6a16a7b4b052cb9136a5 --- /dev/null +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +using AComputeType = ck::bf8_t; +using BComputeType = ck::f8_t; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 8; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 128; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; +static constexpr std::array input_lengths{G, N, C, Di, Hi, Wi}; +static constexpr std::array filter_lengths{G, K, C, Z, Y, X}; +static constexpr std::array output_lengths{G, N, K, Do, Ho, Wo}; +static constexpr std::array input_strides{ + N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C}; +static constexpr std::array weights_strides{ + K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C}; +static constexpr std::array output_strides{ + N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K}; +static constexpr std::array conv_filter_strides{1, 1, 1}; +static constexpr std::array conv_filter_dilations{1, 1, 1}; +static constexpr std::array input_left_pads{1, 1, 1}; +static constexpr std::array input_right_pads{1, 1, 1}; + +int main() +{ + return run_grouped_conv_bwd_weight(input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/12_elementwise_normalization/CMakeLists.txt b/client_example/12_elementwise_normalization/CMakeLists.txt index 1ba0e1279a402ae00a6e3aead9de6c12c61e14c9..738647de595ad0151a12751d0d9c6b72a765ae2c 100644 --- a/client_example/12_elementwise_normalization/CMakeLists.txt +++ b/client_example/12_elementwise_normalization/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(client_elementwise_layernorm2d elementwise_layernorm2d.cpp) -target_link_libraries(client_elementwise_layernorm2d PRIVATE composable_kernel::device_operations) +target_link_libraries(client_elementwise_layernorm2d PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp index bc4a6fe0bfa9e118bbd6ba32ecc7dd68f3b8b2c3..8326f0758c20e76b8cfaf2b2e1829f3c28864d0b 100644 --- a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp +++ b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp @@ -142,6 +142,7 @@ int main() << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() diff --git a/client_example/13_batchnorm/CMakeLists.txt b/client_example/13_batchnorm/CMakeLists.txt index fc4f9d395c4a8b45e34db2ad0c5622484d379dba..420ea25752082117c6e99d8fdaebf3ed11883a3d 100644 --- a/client_example/13_batchnorm/CMakeLists.txt +++ b/client_example/13_batchnorm/CMakeLists.txt @@ -1,6 +1,6 @@ add_executable(client_batchnorm_fwd_nhwc batchnorm_fwd_nhwc.cpp) add_executable(client_batchnorm_bwd_nhwc batchnorm_bwd_nhwc.cpp) add_executable(client_batchnorm_infer_nhwc batchnorm_infer_nhwc.cpp) -target_link_libraries(client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_operations) -target_link_libraries(client_batchnorm_bwd_nhwc PRIVATE composable_kernel::device_operations) -target_link_libraries(client_batchnorm_infer_nhwc PRIVATE composable_kernel::device_operations) +target_link_libraries(client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_other_operations) +target_link_libraries(client_batchnorm_bwd_nhwc PRIVATE composable_kernel::device_other_operations) +target_link_libraries(client_batchnorm_infer_nhwc PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/14_instance_id/CMakeLists.txt b/client_example/14_instance_id/CMakeLists.txt index 87b2a9a0cbc82a82c9567266178ee431fb9149e7..6ba0e59f5aebeb33f63333d026eeaad24e2258fd 100644 --- a/client_example/14_instance_id/CMakeLists.txt +++ b/client_example/14_instance_id/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(client_batchnorm_fwd_instance_id batchnorm_fwd_instance_id.cpp) -target_link_libraries(client_batchnorm_fwd_instance_id PRIVATE composable_kernel::device_operations) +target_link_libraries(client_batchnorm_fwd_instance_id PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/15_convnd_bwd_data/CMakeLists.txt b/client_example/15_convnd_bwd_data/CMakeLists.txt index 8a60a71674f84f6005deadf3744b7df346c0ac3c..f35cd82d798a62cedf37a38028d4e9c345525e32 100644 --- a/client_example/15_convnd_bwd_data/CMakeLists.txt +++ b/client_example/15_convnd_bwd_data/CMakeLists.txt @@ -1,5 +1,5 @@ add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp) add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp) -target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_operations) -target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations) +target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/15_gemm_add_multiply/CMakeLists.txt b/client_example/15_gemm_add_multiply/CMakeLists.txt index fd2dcf9614016b93b43a52c2940d1701ee623584..4b4d762003b1d0d9f58ce42183ecd670a7effdff 100644 --- a/client_example/15_gemm_add_multiply/CMakeLists.txt +++ b/client_example/15_gemm_add_multiply/CMakeLists.txt @@ -1,3 +1,3 @@ add_executable(client_gemm_add_multiply gemm_add_multiply.cpp) -target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_operations) \ No newline at end of file +target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations) \ No newline at end of file diff --git a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp index c74d7c6bd8cf9cec54168792a8718768827d7b33..cde4713b2365c289bb85556ad544c37e2b00e40e 100644 --- a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp +++ b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp @@ -204,6 +204,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/15_reduce/CMakeLists.txt b/client_example/15_reduce/CMakeLists.txt index d52675ba834e666d0a480d5c778baeacb223a24a..a944af5e54d27b014956951a78daa2bacac9573b 100644 --- a/client_example/15_reduce/CMakeLists.txt +++ b/client_example/15_reduce/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(client_reduce_nhwc_c reduce_nhwc_c.cpp) -target_link_libraries(client_reduce_nhwc_c PRIVATE composable_kernel::device_operations) +target_link_libraries(client_reduce_nhwc_c PRIVATE composable_kernel::device_reduction_operations) diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index e2580a370ca480e0b9ec101e5715ac9a631a72f6..5279e3dfcf8196533f59c419d2b3c81d0b53c249 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -1,5 +1,15 @@ -add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp) -add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) +if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp) + target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_conv_operations) -target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations) -target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations) +endif() + +if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp) + target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations) +endif() + +if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) + target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_conv_operations) +endif() diff --git a/client_example/16_convnd_fwd/common.hpp b/client_example/16_convnd_fwd/common.hpp index 449c9466e829baa8658f17fe8e9ae86c5a297238..a5b7c5b42e7901fd5bb75d252b50f73fe61cb589 100644 --- a/client_example/16_convnd_fwd/common.hpp +++ b/client_example/16_convnd_fwd/common.hpp @@ -11,7 +11,7 @@ #include "ck/ck.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -94,7 +94,8 @@ template + ck::index_t NumNonSpatialDim = 3, + typename ComputeType = InDataType> bool run_grouped_conv_fwd(std::array in_lengths, std::array wei_lengths, std::array out_lengths) @@ -173,18 +174,19 @@ bool run_grouped_conv_fwd(std::array(out_lengths, wei_lengths); std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size; - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, - OutLayout, - InDataType, - WeiDataType, - ck::Tuple<>, - OutDataType, - PassThrough, - PassThrough, - PassThrough>; + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough, + ComputeType>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1651ec2f39356dd0edec755385c3716342502077 --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt index 659e6769d8e69f227ea1ef5920c3951bba706f0d..fd315afbd2503f39a3f7b58c1b7d41311d5e5e58 100644 --- a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt +++ b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp) -target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_operations) \ No newline at end of file +target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) \ No newline at end of file diff --git a/client_example/18_groupnorm/CMakeLists.txt b/client_example/18_groupnorm/CMakeLists.txt index 17c88cb61bc5daaa3ac0c6a92459c147691be2fe..e04c26d8e7bc220825912e0a38297b75b6a93370 100644 --- a/client_example/18_groupnorm/CMakeLists.txt +++ b/client_example/18_groupnorm/CMakeLists.txt @@ -1,2 +1,8 @@ -add_executable(client_groupnorm_swish groupnorm_swish.cpp) -target_link_libraries(client_groupnorm_swish PRIVATE composable_kernel::device_operations) +add_executable(client_groupnorm_bwd_data groupnorm_bwd_data.cpp) +target_link_libraries(client_groupnorm_bwd_data PRIVATE composable_kernel::device_other_operations) + +add_executable(client_groupnorm_bwd_gamma_beta groupnorm_bwd_gamma_beta.cpp) +target_link_libraries(client_groupnorm_bwd_gamma_beta PRIVATE composable_kernel::device_other_operations) + +add_executable(client_groupnorm_swish_fwd groupnorm_swish_fwd.cpp) +target_link_libraries(client_groupnorm_swish_fwd PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/18_groupnorm/groupnorm_bwd_data.cpp b/client_example/18_groupnorm/groupnorm_bwd_data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..01ca21ba5788914aba76b49df34b214f99b4a6be --- /dev/null +++ b/client_example/18_groupnorm/groupnorm_bwd_data.cpp @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_data.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DXDataType = float; + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 32; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t G = 64; + ck::index_t C = 128; + + std::size_t length = N * H * W * G * C; + + std::vector strideDy = {H * W * G * C, W * G * C, G * C, C, 1}; + std::vector strideX = strideDy; + std::vector strideDx = strideDy; + + std::vector strideGamma = {0, 0, 0, C, 1}; + std::vector strideMeanInvStd = {G, 0, 0, 1, 0}; + + SimpleDeviceMem dy_dev(sizeof(DYDataType) * length); + SimpleDeviceMem x_dev(sizeof(XDataType) * length); + SimpleDeviceMem gamma_dev(sizeof(GammaDataType) * G * C); + SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * N * G); + SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * N * G); + SimpleDeviceMem dx_dev(sizeof(DXDataType) * length); + + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationBwdData; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, + strideDy, + strideX, + strideGamma, + strideMeanInvStd, + strideMeanInvStd, + strideDx, + {1, 2, 4}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = sizeof(DYDataType) * length + sizeof(XDataType) * length + + sizeof(GammaDataType) * G * C + + sizeof(MeanInvStdDataType) * N * G * 2 + + sizeof(DXDataType) * length; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, + strideDy, + strideX, + strideGamma, + strideMeanInvStd, + strideMeanInvStd, + strideDx, + {1, 2, 4}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp b/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c2fbe285df74b452bb6965e35aa68273a6b1c43f --- /dev/null +++ b/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" + +#include "ck/library/tensor_operation_instance/gpu/groupnorm_bwd_gamma_beta.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DGammaDataType = float; +using DBetaDataType = float; + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 32; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t G = 64; + ck::index_t C = 128; + + std::size_t length = N * H * W * G * C; + + std::vector strideDy = {H * W * G * C, W * G * C, G * C, C, 1}; + std::vector strideX = strideDy; + std::vector strideMeanInvStd = {G, 0, 0, 1, 0}; + std::vector strideDGammaBeta = {C, 1}; + + SimpleDeviceMem dy_dev(sizeof(DYDataType) * length); + SimpleDeviceMem x_dev(sizeof(XDataType) * length); + SimpleDeviceMem mean_dev(sizeof(MeanInvStdDataType) * N * G); + SimpleDeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * N * G); + SimpleDeviceMem dgamma_dev(sizeof(DGammaDataType) * G * C); + SimpleDeviceMem dbeta_dev(sizeof(DBetaDataType) * G * C); + + using DeviceOp = + ck::tensor_operation::device::DeviceNormalizationBwdGammaBeta; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::size_t num_bytes = sizeof(DYDataType) * length + sizeof(XDataType) * length + + sizeof(GammaDataType) * G * C + sizeof(MeanInvStdDataType) * N * G * 2 + + sizeof(DGammaDataType) * G * C + sizeof(DBetaDataType) * G * C; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, + strideDy, + strideX, + strideMeanInvStd, + strideMeanInvStd, + {G, C}, + strideDGammaBeta, + strideDGammaBeta, + {0, 1, 2}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, + strideDy, + strideX, + strideMeanInvStd, + strideMeanInvStd, + {G, C}, + strideDGammaBeta, + strideDGammaBeta, + {0, 1, 2}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/18_groupnorm/groupnorm_swish.cpp b/client_example/18_groupnorm/groupnorm_swish_fwd.cpp similarity index 52% rename from client_example/18_groupnorm/groupnorm_swish.cpp rename to client_example/18_groupnorm/groupnorm_swish_fwd.cpp index e1d198d2282b85f21fb4c04afde443133c801b15..d10d16bf9d91f79a9e181777b0a4ef148d57a79f 100644 --- a/client_example/18_groupnorm/groupnorm_swish.cpp +++ b/client_example/18_groupnorm/groupnorm_swish_fwd.cpp @@ -7,17 +7,19 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp" +#include "ck/library/tensor_operation_instance/gpu/normalization_fwd_swish.hpp" -using XDataType = ck::half_t; -using GammaDataType = float; -using BetaDataType = float; -using YDataType = ck::half_t; -using ComputeDataType = float; -using Swish = ck::tensor_operation::element_wise::Swish; +using XDataType = ck::half_t; +using GammaDataType = float; +using BetaDataType = float; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using Swish = ck::tensor_operation::element_wise::Swish; + +#define SAVE_MEAN_INV_STD constexpr int Rank = 5; constexpr int NumReduceDim = 3; @@ -49,22 +51,27 @@ int main(int argc, char* argv[]) std::size_t xy_size = N * H * W * G * C; std::size_t gamma_beta_size = G * C; - std::vector xy_strides = {H * W * G * C, W * G * C, G * C, C, 1}; - std::vector gamma_beta_strides = {0, 0, 0, C, 1}; + std::vector xy_strides = {H * W * G * C, W * G * C, G * C, C, 1}; + std::vector gamma_beta_strides = {0, 0, 0, C, 1}; + std::vector save_mean_inv_std_strides = {G, 1}; SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size); SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size); SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size); SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size); - - using DeviceOp = ck::tensor_operation::device::DeviceNormalization; +#ifdef SAVE_MEAN_INV_STD + SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * N * G); + SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * N * G); +#endif + + using DeviceOp = ck::tensor_operation::device::DeviceNormalizationFwd; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< @@ -75,19 +82,26 @@ int main(int argc, char* argv[]) const auto& generic_op_ptr = op_ptrs[0]; auto generic_argument_ptr = - generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths - xy_strides, // xStrides - gamma_beta_strides, // gammaStrides - gamma_beta_strides, // betaStrides - xy_strides, // yStrides - {1, 2, 4}, // reduceDims + generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths + xy_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + xy_strides, // yStrides + save_mean_inv_std_strides, // save_mean Strides + save_mean_inv_std_strides, // save_inv_std Strides + {1, 2, 4}, // reduceDims 1e-6, x_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(), beta_device_buf.GetDeviceBuffer(), y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else nullptr, nullptr, +#endif Swish{}); if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get())) @@ -107,21 +121,29 @@ int main(int argc, char* argv[]) for(int i = 0; i < op_ptrs.size(); ++i) { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths - xy_strides, // xStrides - gamma_beta_strides, // gammaStrides - gamma_beta_strides, // betaStrides - xy_strides, // yStrides - {1, 2, 4}, // reduceDims - 1e-6, - x_device_buf.GetDeviceBuffer(), - gamma_device_buf.GetDeviceBuffer(), - beta_device_buf.GetDeviceBuffer(), - y_device_buf.GetDeviceBuffer(), - nullptr, - nullptr, - Swish{}); + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths + xy_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + xy_strides, // yStrides + save_mean_inv_std_strides, // save_mean Strides + save_mean_inv_std_strides, // save_inv_std Strides + {1, 2, 4}, // reduceDims + 1e-6, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + Swish{}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); @@ -129,12 +151,20 @@ int main(int argc, char* argv[]) if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); std::size_t num_byte = sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size + sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size; +#ifdef SAVE_MEAN_INV_STD + num_byte += sizeof(SaveMeanInvStdDataType) * N * G * 2; +#endif + float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " @@ -165,25 +195,37 @@ int main(int argc, char* argv[]) std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths - xy_strides, // xStrides - gamma_beta_strides, // gammaStrides - gamma_beta_strides, // betaStrides - xy_strides, // yStrides - {1, 2, 4}, // reduceDims - 1e-6, - x_device_buf.GetDeviceBuffer(), - gamma_device_buf.GetDeviceBuffer(), - beta_device_buf.GetDeviceBuffer(), - y_device_buf.GetDeviceBuffer(), - nullptr, - nullptr, - Swish{}); + auto argument_ptr = + op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths + xy_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + xy_strides, // yStrides + save_mean_inv_std_strides, // save_mean Strides + save_mean_inv_std_strides, // save_inv_std Strides + {1, 2, 4}, // reduceDims + 1e-6, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + Swish{}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); } diff --git a/client_example/19_pool/CMakeLists.txt b/client_example/19_pool/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..861c1a32578e7d741c777032f7d22b63b14c098d --- /dev/null +++ b/client_example/19_pool/CMakeLists.txt @@ -0,0 +1,11 @@ +add_executable(client_max_pool2d_fwd max_pool2d_fwd.cpp) +target_link_libraries(client_max_pool2d_fwd PRIVATE composable_kernel::device_other_operations) + +add_executable(client_max_pool2d_bwd max_pool2d_bwd.cpp) +target_link_libraries(client_max_pool2d_bwd PRIVATE composable_kernel::device_other_operations) + +add_executable(client_avg_pool3d_fwd avg_pool3d_fwd.cpp) +target_link_libraries(client_avg_pool3d_fwd PRIVATE composable_kernel::device_other_operations) + +add_executable(client_avg_pool3d_bwd avg_pool3d_bwd.cpp) +target_link_libraries(client_avg_pool3d_bwd PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/19_pool/avg_pool3d_bwd.cpp b/client_example/19_pool/avg_pool3d_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..686d1da3ad61a5cb9ca4c8bdec034591cd354c55 --- /dev/null +++ b/client_example/19_pool/avg_pool3d_bwd.cpp @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/avg_pool3d_bwd.hpp" + +using DOutDataType = ck::half_t; +using DInDataType = ck::half_t; + +using DOutLayout = ck::tensor_layout::convolution::NDHWC; +using DInLayout = ck::tensor_layout::convolution::NDHWC; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{}, mMemSize_(mem_size) + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + void SetZero() const { (void)hipMemset(p_mem_, 0, mMemSize_); } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; + std::size_t mMemSize_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Z = 2; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Di = 30; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_d = 2; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_d = 1; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_d = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_d = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + const ck::index_t Zs = (Z - 1) * window_dilation_d + 1; + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Zs) / window_stride_d + 1; + ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; + + // Pool API only support the order of NCDHW + std::vector in_length = {N, C, Di, Hi, Wi}; + std::vector out_length = {N, C, Do, Ho, Wo}; + std::vector window_spatial_lengths = {Z, Y, X}; + std::vector window_strides = {window_stride_d, window_stride_h, window_stride_w}; + std::vector window_dilations{ + window_dilation_d, window_dilation_h, window_dilation_w}; + std::vector input_left_pads = {in_left_pad_d, in_left_pad_h, in_left_pad_w}; + std::vector input_right_pads = {in_right_pad_d, in_right_pad_h, in_right_pad_w}; + + std::size_t in_tensor_size = N * C * Di * Hi * Wi; + std::size_t out_tensor_size = N * C * Do * Ho * Wo; + + // tensor layout = NDHWC + std::vector in_tensor_stride = {Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C}; + std::vector out_tensor_stride = {Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}; + + SimpleDeviceMem dout_device_buf(sizeof(DOutDataType) * out_tensor_size); + SimpleDeviceMem din_device_buf(sizeof(DInDataType) * in_tensor_size); + + using DeviceOp = ck::tensor_operation::device:: + DeviceAvgPoolBwd<3, DOutDataType, DInDataType, DOutLayout, DInLayout>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_length, + in_length, + out_tensor_stride, + in_tensor_stride, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + din_device_buf.SetZero(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = + in_tensor_size * sizeof(DInDataType) + out_tensor_size * sizeof(DOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_length, + in_length, + out_tensor_stride, + in_tensor_stride, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + din_device_buf.SetZero(); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/19_pool_fwd/avg_pool3d_fwd.cpp b/client_example/19_pool/avg_pool3d_fwd.cpp similarity index 78% rename from client_example/19_pool_fwd/avg_pool3d_fwd.cpp rename to client_example/19_pool/avg_pool3d_fwd.cpp index db8e0569d717f08e9b64e4504d40dd23aab13129..6739a41b2f1abfb4f8795fa5973358fecbfe70b9 100644 --- a/client_example/19_pool_fwd/avg_pool3d_fwd.cpp +++ b/client_example/19_pool/avg_pool3d_fwd.cpp @@ -94,7 +94,6 @@ int main(int argc, char* argv[]) SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size); SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size); - SimpleDeviceMem out_indices_device_buf(sizeof(IndexDataType) * out_tensor_size); using DeviceOp = ck::tensor_operation::device::DevicePoolFwdMakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - static_cast(out_indices_device_buf.GetDeviceBuffer()), - in_length, - window_spatial_lengths, - out_length, - in_tensor_stride, - out_tensor_stride, - out_tensor_stride, - window_strides, - window_dilations, - input_left_pads, - input_right_pads, - {2, 3, 4}); + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + nullptr, + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {2, 3, 4}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); @@ -184,21 +183,21 @@ int main(int argc, char* argv[]) std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - static_cast(out_indices_device_buf.GetDeviceBuffer()), - in_length, - window_spatial_lengths, - out_length, - in_tensor_stride, - out_tensor_stride, - out_tensor_stride, - window_strides, - window_dilations, - input_left_pads, - input_right_pads, - {2, 3, 4}); + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + nullptr, + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {2, 3, 4}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/client_example/19_pool/max_pool2d_bwd.cpp b/client_example/19_pool/max_pool2d_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..53ece7425f99bd6eadc6de1a6f041a12fa0951d0 --- /dev/null +++ b/client_example/19_pool/max_pool2d_bwd.cpp @@ -0,0 +1,280 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp" +#include "ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using DOutDataType = ck::half_t; +using DInDataType = ck::half_t; +using IndexDataType = int32_t; + +// We use pool3d to implement pool2d in this example +using InLayout = ck::tensor_layout::convolution::NDHWC; +using OutLayout = ck::tensor_layout::convolution::NDHWC; + +constexpr ck::index_t InOutRank = 5; +constexpr ck::index_t WindowRank = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +void TransformPool2dparamToPool3d(std::vector& input_lengths, + std::vector& window_lengths, + std::vector& output_lengths, + std::vector& input_stride, + std::vector& output_stride, + std::vector& indices_stride, + std::vector& window_strides, + std::vector& window_dilations, + std::vector& input_left_pads, + std::vector& input_right_pads, + std::vector& pooling_dims) +{ + // NCHW to NCDHW + input_lengths.insert(input_lengths.begin() + 2, 1); + output_lengths.insert(output_lengths.begin() + 2, 1); + input_stride.insert(input_stride.begin() + 2, 0); + output_stride.insert(output_stride.begin() + 2, 0); + indices_stride.insert(indices_stride.begin() + 2, 0); + + // YX to ZYX + window_lengths.insert(window_lengths.begin(), 1); + window_strides.insert(window_strides.begin(), 0); + window_dilations.insert(window_dilations.begin(), 0); + input_left_pads.insert(input_left_pads.begin(), 0); + input_right_pads.insert(input_right_pads.begin(), 0); + + pooling_dims = {2, 3, 4}; +} + +int main(int argc, char* argv[]) +{ + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; + + // Pool API only support the order of NCHW + std::vector in_length = {N, C, Hi, Wi}; + std::vector out_length = {N, C, Ho, Wo}; + std::vector window_spatial_lengths = {Y, X}; + std::vector window_strides = {window_stride_h, window_stride_w}; + std::vector window_dilations = {window_dilation_h, window_dilation_w}; + std::vector input_left_pads = {in_left_pad_h, in_left_pad_w}; + std::vector input_right_pads = {in_right_pad_h, in_right_pad_w}; + std::vector pooling_dims = {2, 3}; + + std::size_t in_tensor_size = N * C * Hi * Wi; + std::size_t out_tensor_size = N * C * Ho * Wo; + + // tensor layout = NHWC + std::vector in_tensor_stride = {C * Hi * Wi, 1, Wi * C, C}; + std::vector out_tensor_stride = {C * Ho * Wo, 1, Wo * C, C}; + + TransformPool2dparamToPool3d(in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + pooling_dims); + + SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size); + SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size); + SimpleDeviceMem indices_device_buf(sizeof(IndexDataType) * out_tensor_size); + SimpleDeviceMem dout_device_buf(sizeof(DOutDataType) * out_tensor_size); + SimpleDeviceMem din_device_buf(sizeof(DInDataType) * in_tensor_size); + + // Generate index data from max pool forward + { + using MaxPoolFwdDeviceOp = + ck::tensor_operation::device::DevicePoolFwd; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + MaxPoolFwdDeviceOp>::GetInstances(); + + auto& op_ptr = op_ptrs[0]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + pooling_dims); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + } + + // Run MaxPool bwd + using MaxPoolBwdDeviceOp = + ck::tensor_operation::device::DeviceMaxPoolBwd; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + MaxPoolBwdDeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_tensor_size, + in_tensor_size, + window_spatial_lengths, + window_strides, + window_dilations); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = in_tensor_size * sizeof(DInDataType) + + out_tensor_size * sizeof(IndexDataType) + + out_tensor_size * sizeof(DOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << "GB / s," + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_tensor_size, + in_tensor_size, + window_spatial_lengths, + window_strides, + window_dilations); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/19_pool_fwd/max_pool2d_fwd.cpp b/client_example/19_pool/max_pool2d_fwd.cpp similarity index 100% rename from client_example/19_pool_fwd/max_pool2d_fwd.cpp rename to client_example/19_pool/max_pool2d_fwd.cpp diff --git a/client_example/19_pool_fwd/CMakeLists.txt b/client_example/19_pool_fwd/CMakeLists.txt deleted file mode 100644 index 13f9f73c83d55c801fdf4609e13f6b0813cb0c67..0000000000000000000000000000000000000000 --- a/client_example/19_pool_fwd/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -add_executable(client_max_pool2d_fwd max_pool2d_fwd.cpp) -target_link_libraries(client_max_pool2d_fwd PRIVATE composable_kernel::device_operations) - -add_executable(client_avg_pool3d_fwd avg_pool3d_fwd.cpp) -target_link_libraries(client_avg_pool3d_fwd PRIVATE composable_kernel::device_operations) \ No newline at end of file diff --git a/client_example/20_splitk_gemm/CMakeLists.txt b/client_example/20_splitk_gemm/CMakeLists.txt index a60bada473b38f041bc409ef42dc4fabb7ecc57d..a3dc8537672d56fce0223e0703669cd4bc665b16 100644 --- a/client_example/20_splitk_gemm/CMakeLists.txt +++ b/client_example/20_splitk_gemm/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) -target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations) +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) + target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp b/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp index 94a57cd029a30d0fc73bae8cb8e43c1d475c95bc..a740c22f91fc8edfd221b014df4dc535687f789f 100644 --- a/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp +++ b/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp @@ -191,6 +191,7 @@ int main(int argc, char* argv[]) << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; // run the best intance + if(found) { auto& op_ptr = op_ptrs[best_op_id]; diff --git a/client_example/21_grouped_gemm_bias/CMakeLists.txt b/client_example/21_grouped_gemm_bias/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..92e31495c2cce9643e0f6da71d05752145347552 --- /dev/null +++ b/client_example/21_grouped_gemm_bias/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) +target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations) diff --git a/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp b/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c758720e10f6cd222ee68ecd5b54069984279c31 --- /dev/null +++ b/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using D0DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Row; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + const int group_count = 16; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, d0_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + d0_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + d0_dev_bufs.emplace_back(sizeof(D0DataType) * + f_matrix_space_size(Ms[i], Ns[i], 0, D0Layout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back( + {a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + std::array{d0_dev_bufs[i].GetDeviceBuffer()}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + std::array{0}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 2); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_grouped_gemm/CMakeLists.txt b/client_example/22_grouped_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..19c613381e294a00fbb139c6c9085d62bf07bc39 --- /dev/null +++ b/client_example/22_grouped_gemm/CMakeLists.txt @@ -0,0 +1,8 @@ +add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp) +target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations) + +add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp) +target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations) + +add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp) +target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations) diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b16fe903875c10f78c3b7e2425185176db49575c --- /dev/null +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp @@ -0,0 +1,236 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + const int group_count = 16; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + {}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 32); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..045fe47c4fdc80fa13721a1d219d35e26cc25dd9 --- /dev/null +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F8; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + const int group_count = 16; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + {}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 16); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8f82140f3fc3963eb1a5ee3b3aad67db69f96d92 --- /dev/null +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +using I8 = int8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = I8; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + const int group_count = 16; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + {}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 32); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_im2col_col2im/CMakeLists.txt b/client_example/22_im2col_col2im/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d938d8cfb3e57a17361de5dd56925dcc4a09f5af --- /dev/null +++ b/client_example/22_im2col_col2im/CMakeLists.txt @@ -0,0 +1,5 @@ +add_executable(client_image_to_column image_to_column.cpp) +target_link_libraries(client_image_to_column PRIVATE composable_kernel::device_other_operations) + +add_executable(client_column_to_image column_to_image.cpp) +target_link_libraries(client_column_to_image PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/22_im2col_col2im/column_to_image.cpp b/client_example/22_im2col_col2im/column_to_image.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ebe63198fbc7d943763289157dd84cd26cece3b --- /dev/null +++ b/client_example/22_im2col_col2im/column_to_image.cpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange.hpp" +#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; + +using ImageLayout = ck::tensor_layout::convolution::NHWGC; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 2; +static constexpr ck::index_t N = 32; // batch size +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 28; // input H +static constexpr ck::index_t Wi = 28; // input W +static constexpr ck::index_t Ho = 28; // output H +static constexpr ck::index_t Wo = 28; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + + std::array in_spatial_lengths{Hi, Wi}; + std::array wei_spatial_lengths{Y, X}; + std::array out_spatial_lengths{Ho, Wo}; + + // We have NHWGC in memory space + // However, CK's API only accepts lengths and strides with order of GNCHW. + // Hence, we need to adjust the order of strides. + std::array image_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; + std::array gemm_strides{Y * X * C, G * Y * X * C, 1}; + + std::array filter_strides{1, 1}; + std::array filter_dilations{1, 1}; + std::array input_left_pads{1, 1}; + std::array input_right_pads{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Ho * Wo * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Hi * Wi * G * C); + + using namespace ck::conv_tensor_rearrange_op; + + using DeviceOp = ck::tensor_operation::device::DeviceConvTensorRearrange; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + G, + N, + C, + in_spatial_lengths, + out_spatial_lengths, + wei_spatial_lengths, + image_strides, + gemm_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(OutDataType) * G * N * Ho * Wo * Y * X * C; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(avg_time < best_avg_time) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_gb_per_sec + << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + G, + N, + C, + in_spatial_lengths, + out_spatial_lengths, + wei_spatial_lengths, + image_strides, + gemm_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/22_im2col_col2im/image_to_column.cpp b/client_example/22_im2col_col2im/image_to_column.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8eafbdc5bbfec61b52ed3ed6dc5d7045d6d285e9 --- /dev/null +++ b/client_example/22_im2col_col2im/image_to_column.cpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange.hpp" +#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; + +using ImageLayout = ck::tensor_layout::convolution::NHWGC; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 2; +static constexpr ck::index_t N = 32; // batch size +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 28; // input H +static constexpr ck::index_t Wi = 28; // input W +static constexpr ck::index_t Ho = 28; // output H +static constexpr ck::index_t Wo = 28; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + + std::array in_spatial_lengths{Hi, Wi}; + std::array wei_spatial_lengths{Y, X}; + std::array out_spatial_lengths{Ho, Wo}; + + // We have NHWGC in memory space + // However, CK's API only accepts lengths and strides with order of GNCHW. + // Hence, we need to adjust the order of strides. + std::array image_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; + std::array gemm_strides{Y * X * C, G * Y * X * C, 1}; + + std::array filter_strides{1, 1}; + std::array filter_dilations{1, 1}; + std::array input_left_pads{1, 1}; + std::array input_right_pads{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * Y * X * C); + + using namespace ck::conv_tensor_rearrange_op; + + using DeviceOp = ck::tensor_operation::device::DeviceConvTensorRearrange; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + G, + N, + C, + in_spatial_lengths, + out_spatial_lengths, + wei_spatial_lengths, + image_strides, + gemm_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(OutDataType) * G * N * Ho * Wo * Y * X * C; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(avg_time < best_avg_time) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_gb_per_sec + << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + G, + N, + C, + in_spatial_lengths, + out_spatial_lengths, + wei_spatial_lengths, + image_strides, + gemm_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/23_elementwise_transpose/CMakeLists.txt b/client_example/23_elementwise_transpose/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..6b2421d88136a9306a79d4b5dc9777670072354e --- /dev/null +++ b/client_example/23_elementwise_transpose/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_elementwise_transpose3d elementwise_transpose_3d.cpp) +target_link_libraries(client_elementwise_transpose3d PRIVATE composable_kernel::device_other_operations) diff --git a/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp b/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..65ba46fcd2736c853754d86e09cccb93be66c3be --- /dev/null +++ b/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/transpose_3d.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + const int N = 16; + const int C = 8; + const int D = 8; + const int H = 8; + const int W = 8; + + std::vector ncdhw = {N, C, D, H, W}; + std::vector nchwd = {N, C, H, W, D}; + auto size = N * C * D * H * W; + + std::array ab_lengths{N, C, H, W, D}; + std::array a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W + std::array b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, C, H, W, D + + SimpleDeviceMem a_dev_buf(sizeof(ADataType) * size); + SimpleDeviceMem b_dev_buf(sizeof(BDataType) * size); + + std::array input = {a_dev_buf.GetDeviceBuffer()}; + std::array output = {b_dev_buf.GetDeviceBuffer()}; + + using DeviceElementwisePermuteInstance = ck::tensor_operation::device:: + DeviceElementwise, ck::Tuple, PassThrough, 5>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceElementwisePermuteInstance>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = + sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) + + sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]); + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b4895db891861c501c9a9a733a1b83aeb9898288 --- /dev/null +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -0,0 +1,40 @@ +# Fwd scaleadd scaleadd relu +add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 + grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 + grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 + grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 + grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 PRIVATE composable_kernel::device_conv_operations) +# Fwd scaleadd AB +add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp32 + grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp32 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp16 + grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp16 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_ab_bf16 + grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_bf16 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_ab_int8 + grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composable_kernel::device_conv_operations) +# Fwd bilinear +add_executable(client_grouped_convnd_fwd_bilinear_residual_fp16 + grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp) +target_link_libraries(client_grouped_convnd_fwd_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations) +# Bwd data bilinear +add_executable(client_grouped_convnd_bwd_data_bilinear_residual_fp16 + grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp) +target_link_libraries(client_grouped_convnd_bwd_data_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bb106e8d8eef32e9034d2681fc471b8d6794bde3 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp @@ -0,0 +1,217 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_bwd_data_bilinear() +{ + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple, + InDataType, + PassThrough, + PassThrough, + Bilinear>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {in.GetDeviceBuffer()}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {in_lengths}, + {in_strides}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Bilinear{2.f, 2.f}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X + + 3 * G * N * Di * Hi * Wi * C; + std::size_t num_bytes = 2 * sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {in.GetDeviceBuffer()}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {in_lengths}, + {in_strides}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Bilinear{2.f, 2.f}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_bwd_data_bilinear(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..32ab481319449d473f933b1862e79c7f25cdeda1 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_fwd_bilinear() +{ + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space. + // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW. + // Hence, we need to adjust the order of strides. + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + // Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW) + std::array bias_lengths{G, 1, K, 1, 1, 1}; + std::array bias_strides{K, 0, 1, 0, 0, 0}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Di * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + Bilinear>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {out.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {out_lengths}, + {out_strides}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Bilinear{2.f, 2.f}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * G * N * K * C * Ho * Wo * Y * X + 3 * N * Ho * Wo * G * K; + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(WeiDataType) * G * K * Y * X * C + + sizeof(OutDataType) * 2 * N * Ho * Wo * G * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {out.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {out_lengths}, + {out_strides}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Bilinear{2.f, 2.f}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_fwd_bilinear(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc new file mode 100644 index 0000000000000000000000000000000000000000..3f6f7b07737cecc4122c9dea62c87b19aafd5400 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_ab.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_fwd_scaleadd_ab() +{ + constexpr ck::index_t NumAs = 2; + constexpr ck::index_t NumBs = 2; + + constexpr float scale = 1.5f; + + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space. + // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW. + // Hence, we need to adjust the order of strides. + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + using InputDtype = ck::tuple_element_t<0, InDataType>; + using InputBiasDtype = ck::tuple_element_t<1, InDataType>; + using WeightDtype = ck::tuple_element_t<0, WeiDataType>; + using WeightBiasDtype = ck::tuple_element_t<1, WeiDataType>; + + SimpleDeviceMem in(sizeof(InputDtype) * N * Di * Hi * Wi * G * C); + SimpleDeviceMem in_bias(sizeof(InputBiasDtype) * N * Di * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeightDtype) * G * K * Z * Y * X * C); + SimpleDeviceMem wei_bias(sizeof(WeightBiasDtype) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + ScaleAdd, + ScaleAdd, + PassThrough>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::array as = {in.GetDeviceBuffer(), in_bias.GetDeviceBuffer()}; + std::array bs = {wei.GetDeviceBuffer(), wei_bias.GetDeviceBuffer()}; + std::array ds{}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(as, + bs, + ds, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + ScaleAdd{scale}, + ScaleAdd{scale}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Z * Y * X + + N * Di * Hi * Wi * G * C + G * K * Z * Y * X * C; + std::size_t num_bytes = 2 * sizeof(InDataType) * N * Di * Hi * Wi * G * C + + 2 * sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * N * Do * Ho * Wo * G * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(as, + bs, + ds, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + ScaleAdd{scale}, + ScaleAdd{scale}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..fef3f7428c2149cd11892a987c320eeba93135cc --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = ck::Tuple; +using WeiDataType = ck::Tuple; +using OutDataType = ck::bhalf_t; + +#include "grouped_conv_fwd_scaleadd_ab.inc" + +int main() { return execute_conv_fwd_scaleadd_ab(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..43db279191331edf969bdc15e526f4694c1650ec --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = ck::Tuple; +using WeiDataType = ck::Tuple; +using OutDataType = ck::half_t; + +#include "grouped_conv_fwd_scaleadd_ab.inc" + +int main() { return execute_conv_fwd_scaleadd_ab(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cccec47701fbf52b9bc9ab2f37c49e29387224f9 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = ck::Tuple; +using WeiDataType = ck::Tuple; +using OutDataType = float; + +#include "grouped_conv_fwd_scaleadd_ab.inc" + +int main() { return execute_conv_fwd_scaleadd_ab(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..28674c8abeb702211b699ca5d1c864837bb54990 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = ck::Tuple; +using WeiDataType = ck::Tuple; +using OutDataType = int8_t; + +#include "grouped_conv_fwd_scaleadd_ab.inc" + +int main() { return execute_conv_fwd_scaleadd_ab(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc new file mode 100644 index 0000000000000000000000000000000000000000..4e3cf69637c55178c4a7007891bd918c1cacf5df --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scaleadd_scaleadd_relu.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using BiasLayout = ck::tensor_layout::convolution::G_K; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ScaleAddScaleAddRelu = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_fwd_scaleadd_scaleadd_relu() +{ + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space. + // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW. + // Hence, we need to adjust the order of strides. + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + // Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW) + std::array bias_lengths{G, 1, K, 1, 1, 1}; + std::array bias_strides{K, 0, 1, 0, 0, 0}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Di * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K); + SimpleDeviceMem d0(sizeof(std::tuple_element_t<0, DDataTypes>) * N * Do * Ho * Wo * G * K); + SimpleDeviceMem d1(sizeof(std::tuple_element_t<1, DDataTypes>) * G * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD< + NumDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, std::tuple_element_t<1, DDataTypes>>, + OutDataType, + PassThrough, + PassThrough, + ScaleAddScaleAddRelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {d0.GetDeviceBuffer(), d1.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {out_lengths, bias_lengths}, + {out_strides, bias_strides}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ScaleAddScaleAddRelu{2.f, 2.f}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * G * N * K * C * Ho * Wo * Y * X + 2 * N * Ho * Wo * G * K; + std::size_t num_bytes = + sizeof(InDataType) * N * Hi * Wi * G * C + sizeof(WeiDataType) * G * K * Y * X * C + + (sizeof(OutDataType) + sizeof(std::tuple_element_t<0, DDataTypes>) + + sizeof(std::tuple_element_t<1, DDataTypes>)) * + N * Ho * Wo * G * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {d0.GetDeviceBuffer(), d1.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {out_lengths, bias_lengths}, + {out_strides, bias_strides}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + ScaleAddScaleAddRelu{2.f, 2.f}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7a32c4f74277250de93f2c06d6419ccc39b13514 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = ck::bhalf_t; +using WeiDataType = ck::bhalf_t; +using OutDataType = ck::bhalf_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc" + +int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e3e91072b34104c85aab41363e06d34739dd4831 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc" + +int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e7ed96b6a04a67b85e366e2a1405561ae65df925 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc" + +int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9959664d2a110483f26340e25224cee16997c5e9 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using OutDataType = int8_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +#include "grouped_conv_fwd_scaleadd_scaleadd_relu.inc" + +int main() { return execute_conv_fwd_scaleadd_scaleadd_relu(); } diff --git a/client_example/25_wrapper/CMakeLists.txt b/client_example/25_wrapper/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fdfc1d8d2ec17fcd183c2d22f64977a9ba2ec245 --- /dev/null +++ b/client_example/25_wrapper/CMakeLists.txt @@ -0,0 +1,12 @@ +add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp) +target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations) +add_executable(client_wrapper_img2col wrapper_img2col.cpp) +target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations) +if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR + GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR + GPU_TARGETS MATCHES "gfx942") + add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp) + target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations) + add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp) + target_link_libraries(client_wrapper_optimized_gemm PRIVATE composable_kernel::device_other_operations) +endif() diff --git a/client_example/25_wrapper/README.md b/client_example/25_wrapper/README.md new file mode 100644 index 0000000000000000000000000000000000000000..eba3de017f41bf7cc7de8a4f2cd581dfd81c0093 --- /dev/null +++ b/client_example/25_wrapper/README.md @@ -0,0 +1,177 @@ +# Composable Kernel wrapper GEMM tutorial + +This tutorial demonstrates how to implement matrix multiplication using Composable Kernel (CK) +wrapper. We present the base version of GEMM without most of the available optimizations; however, +it's worth noting that CK has kernels with different optimizations. + +To implement these optimizations, you can use the CK wrapper or directly use available instances in +CK. You can also refer to the +[optimized GEMM example](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp), +that uses CK wrapper based on the +[`gridwise_gemm_xdlops_v2r3`](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp) implementation. + +The kernel definition should look similar to: + +```cpp +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +``` + +We pass pointers to global memory and matrix dimensions via arguments. Additionally, we pass +selected lengths of processed data through each block (`tile_shape`) and thread layout +(`thread_layout`). For compilation time parameters, we define the data type, +[traits for the GEMM operation](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp) +and scalar per vector value during copy. + +Step 1: Create layouts for global and LDS memory. + +```cpp + // Specify layouts for global memory. + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + + // Specify layouts for tiles. + constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto c_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{})); + + // Apply padding for global memory. + auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout)); + auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout)); + auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout)); +``` + +We pad layouts for global tensors in case M, N, and K are not divisible by `MPerBlock`, `NPerBlock`, or +`KPerBlock`. + +Step 2: Create tensors for global and LDS memory. + +```cpp + // Make tensors for global memory. + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_global_layout_padded); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_global_layout_padded); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_global_layout_padded); + + // Allocate LDS memory. + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)]; + + // Make tensors for lds memory. + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); +``` + +We must specify parameters for copy and convert block indexes to tuple: + +```cpp + // Specify block index as tuple. + const auto block_idxs = ck::make_tuple(static_cast(blockIdx.x), + static_cast(blockIdx.y), + ck::wrapper::slice()); + // Specify access parameters for copy. + using DimAccessOrder = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t vector_dim = 1; +``` + +We create a local tile (per block) and local partitions (per thread) for the global memory (`C`). We also +define and clear an output register (`c_vgpr_reg`) for the accumulation. + +```cpp + auto c_global_local_tile = ck::wrapper::make_local_tile( + c_global_tensor, + tile_shape, + block_idxs, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + // Create C vgpr to accumulate results. + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + // Clear C vgpr. + ck::wrapper::clear(c_vgpr_reg); +``` + +We use two specific functions for `blockwise_gemm`: `make_blockwise_gemm_xdl_c_local_partition` and +`make_blockwise_gemm_xdl_c_vgpr`. This helps to choose the appropriate partition for the `C` output +and define tensors with specific layouts for `blockwise_gemm`. In the following step, we use only +generic functions for the CK wrapper. + +Step 3: Create the compute loop. + +```cpp + const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock); + ck::index_t i = 0; + do + { + // Get KPerBlock slice. + const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock); + auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice); + auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice); + // Create local tiles for A and B. + auto a_global_local_tile = ck::wrapper::make_local_tile( + a_global_tensor_k_slice, + tile_shape, + block_idxs, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{})); + auto b_global_local_tile = ck::wrapper::make_local_tile( + b_global_tensor_k_slice, + tile_shape, + block_idxs, + make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{})); + // Copy from global to LDS. + ck::wrapper::blockwise_copy( + a_global_local_tile, a_lds_tensor, thread_layout); + ck::wrapper::blockwise_copy( + b_global_local_tile, b_lds_tensor, thread_layout); + // Synchronize lds. + ck::block_sync_lds(); + // Execute blockwise GEMM. + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + + ++i; + } while(i < num_loop); +``` + +Loop iterate over `K / KPerBlock`. Each time a local tile is created for A and B tensors (tensor per block), +data is copied from global memory to LDS. The `blockwise_gemm` function performs the GEMM +operation on `a_lds_tensor` and `b_lds_tensor`, and stores results in `c_vgpr_reg`. + +The end result from `c_vgpr_reg` is stored in the `C` local partition (tensor per thread): + +```cpp + ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +``` + +If you want to dive deep into the details, you can find the entire example +[here](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_basic_gemm.cpp). diff --git a/client_example/25_wrapper/tensor_transform_using_wrapper.cpp b/client_example/25_wrapper/tensor_transform_using_wrapper.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4b25d85e2d34894f8acade03afc07cd7fa85892f --- /dev/null +++ b/client_example/25_wrapper/tensor_transform_using_wrapper.cpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" + +#include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/sequence.hpp" + +#include "ck/wrapper/layout.hpp" + +using DataType = int; + +template +void Print1d(const Layout& layout) +{ + std::cout << "Print1d" << std::endl; + for(ck::index_t w = 0; w < ck::wrapper::size(layout); w++) + { + std::cout << layout(ck::make_tuple(w)) << " "; + } + std::cout << std::endl; +} + +template +void Print2d(const Layout& layout) +{ + std::cout << "Print2d" << std::endl; + for(ck::index_t h = 0; h < ck::wrapper::size<0>(layout); h++) + { + for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) + { + std::cout << layout(ck::make_tuple(h, w)) << " "; + } + std::cout << std::endl; + } +} + +// Print in (x,y),z pattern +template +void Print3dCustom(const Layout& layout) +{ + std::cout << "Print3dCustom" << std::endl; + for(ck::index_t d = 0; d < ck::wrapper::size<0>(ck::wrapper::get<0>(layout)); d++) + { + for(ck::index_t h = 0; h < ck::wrapper::size<1>(ck::wrapper::get<0>(layout)); h++) + { + for(ck::index_t w = 0; w < ck::wrapper::size<1>(layout); w++) + { + std::cout << layout(ck::make_tuple(ck::make_tuple(d, h), w)) << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; + } +} + +int main() +{ + // Layout traverse in row-major + std::cout << "Note: Layout traverse in column-major" << std::endl; + // Basic descriptor 0, 1, 2, ... 30, 31 (compile-time descriptor) + // (dims:4,8 strides:1,4) + const auto shape_4x8 = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}); + const auto layout_4x8_s1x4 = ck::wrapper::make_layout(shape_4x8); + std::cout << "dims:4,8 strides:1,4" << std::endl; + Print2d(layout_4x8_s1x4); + using Cord1x1Type = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t offset_1x1 = layout_4x8_s1x4.template operator()(); + std::cout << "Constexpr calculated [1, 1] offset:" << offset_1x1 << std::endl; + + // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (runtime descriptor) + // dims:4,(2,4) strides:2,(1,8) + const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); + const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); + const auto layout_4x2x4_s2x1x8 = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8); + + std::cout << "dims:4,(2,4) strides:2,(1,8)" << std::endl; + Print2d(layout_4x2x4_s2x1x8); + + // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) + // dims:(2,2),(2,4) strides:((1,4),(2,8) + const auto shape_2x2x2x4 = ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), + ck::make_tuple(ck::Number<2>{}, ck::Number<4>{})); + const auto strides_s1x4x2x8 = ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), + ck::make_tuple(ck::Number<2>{}, ck::Number<8>{})); + static const auto layout_2x2x2x4_s1x4x2x8 = + ck::wrapper::make_layout(shape_2x2x2x4, strides_s1x4x2x8); + + std::cout << "dims:(2,2),(2,4) strides:(1,4),(2,8)" << std::endl; + Print2d(layout_2x2x2x4_s1x4x2x8); + Print3dCustom(layout_2x2x2x4_s1x4x2x8); + + // Basic descriptor 0, 1, 8, 9, 16, 17, ... 30, 31 (compile-time descriptor) + // dims:((2,2),2),4 strides:((1,4),2),8 + // Transform to 2d + const auto shape_2x2x2x4_nested = ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number<2>{}, ck::Number<2>{}), ck::Number<2>{}), + ck::Number<4>{}); + const auto strides_s1x4x2x8_nested = ck::make_tuple( + ck::make_tuple(ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}), ck::Number<2>{}), + ck::Number<8>{}); + static const auto layout_2x2x2x4_s1x4x2x8_nested = + ck::wrapper::make_layout(shape_2x2x2x4_nested, strides_s1x4x2x8_nested); + + std::cout << "dims:((2,2),2),4 strides:((1,4),2),8" << std::endl; + Print1d(layout_2x2x2x4_s1x4x2x8_nested); + Print2d(layout_2x2x2x4_s1x4x2x8_nested); + Print3dCustom(layout_2x2x2x4_s1x4x2x8_nested); + + return 0; +} diff --git a/client_example/25_wrapper/wrapper_basic_gemm.cpp b/client_example/25_wrapper/wrapper_basic_gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1f1a4de75121df8fddb2744ddf2b905d9beac787 --- /dev/null +++ b/client_example/25_wrapper/wrapper_basic_gemm.cpp @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/host_tensor.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/operations/gemm.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +{ + constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); + constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); + constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); + + // Specify layouts for global memory. + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + // Specify layouts for tiles. + constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto c_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{})); + // Apply padding for global memory. + auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout)); + auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout)); + auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout)); + // Make tensors for global memory. + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_global_layout_padded); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_global_layout_padded); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_global_layout_padded); + // Allocate lds memory. + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)]; + // Make tensors for lds memory. + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); + // Specify block index as tuple. + const auto block_idxs = ck::make_tuple(static_cast(blockIdx.x), + static_cast(blockIdx.y), + ck::wrapper::slice()); + // Specify access parameters for copy. + using DimAccessOrder = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t vector_dim = 1; + // Create tile and partition for C. Use specific function for blockwise_gemm to assign the + // appropriate partitions. + auto c_global_local_tile = ck::wrapper::make_local_tile( + c_global_tensor, + tile_shape, + block_idxs, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + // Create C vgpr to accumulate results. + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + // Clear C vgpr. + ck::wrapper::clear(c_vgpr_reg); + + // Iterate over K with KPerBlock step. + const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock); + ck::index_t i = 0; + do + { + // Get KPerBlock slice. + const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock); + auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice); + auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice); + // Create local tiles for A and B. + auto a_global_local_tile = ck::wrapper::make_local_tile( + a_global_tensor_k_slice, + tile_shape, + block_idxs, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{})); + auto b_global_local_tile = ck::wrapper::make_local_tile( + b_global_tensor_k_slice, + tile_shape, + block_idxs, + make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{})); + // Copy from global to lds. + ck::wrapper::blockwise_copy( + a_global_local_tile, a_lds_tensor, thread_layout); + ck::wrapper::blockwise_copy( + b_global_local_tile, b_lds_tensor, thread_layout); + // Synchronize lds. + ck::block_sync_lds(); + // Execute blockwise gemm. + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + + ++i; + } while(i < num_loop); + // Copy vgpr results to C global memory. + ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +} + +template +void PerformGemm(const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape& tile_shape, + const ThreadLayout& thread_layout) +{ + // Global memory buffers + SimpleDeviceMem a_mem(M * K * sizeof(DataType)); + SimpleDeviceMem b_mem(K * N * sizeof(DataType)); + SimpleDeviceMem c_mem(M * N * sizeof(DataType)); + + const ck::index_t grid_size_x = + ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)); + const ck::index_t grid_size_y = + ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); + + const auto kernel = + DeviceGemm; + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + kernel, + dim3(grid_size_x, grid_size_y, 1), + dim3(ck::wrapper::size(thread_layout)), + 0, + a_mem.GetDeviceBuffer(), + b_mem.GetDeviceBuffer(), + c_mem.GetDeviceBuffer(), + M, + N, + K, + tile_shape, + thread_layout); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << std::endl; +} + +int main(int argc, char* argv[]) +{ + using DataType = ck::half_t; + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}), + ck::make_tuple(ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{}); + PerformGemm( + 3840, 4096, 4096, tile_shape, thread_layout); + return 0; +} +// MI300X Perf: 0.471337 ms, 273.369 TFlops, 204.671 GB/s, diff --git a/client_example/25_wrapper/wrapper_img2col.cpp b/client_example/25_wrapper/wrapper_img2col.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2a4034d62fad4217a438e439de8659002433825d --- /dev/null +++ b/client_example/25_wrapper/wrapper_img2col.cpp @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" + +static constexpr ck::index_t NumDimSpatial = 3; +using DataType = float; +using InputLayout = ck::tensor_layout::convolution::NDHWGC; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ +DeviceImageToColumnPad0(InputTensor input_tensor, + OutputTensor output_tensor, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +{ + // grid layout (dim1, dim0) + const auto block_idxs = + ck::make_tuple(static_cast(blockIdx.y), static_cast(blockIdx.x)); + + // Get local tiles for global memory + auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs); + auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs); + + // Get partition per thread + const auto input_local_partition = + ck::wrapper::make_local_partition(input_local_tile, thread_layout, threadIdx.x); + auto output_local_partition = + ck::wrapper::make_local_partition(output_local_tile, thread_layout, threadIdx.x); + + // Perform copy + using DimAccessOrder = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t vector_dim = 1; + constexpr ck::index_t scalar_per_vector = 4; + ck::wrapper::copy(input_local_partition, + output_local_partition); +} + +void PerformImageToColumnPad0(const ck::index_t G, + const ck::index_t N, + const ck::index_t Di, + const ck::index_t Hi, + const ck::index_t Wi, + const ck::index_t Do, + const ck::index_t Ho, + const ck::index_t Wo, + const ck::index_t C, + const ck::index_t Z, + const ck::index_t Y, + const ck::index_t X, + std::array filter_strides, + std::array filter_dilations) +{ + const ck::index_t ZYXC = Z * Y * X * C; + const ck::index_t GC = G * C; + + // shape: (G, (Wo, Ho, Do, N)), (C, X, Y, Z)) + const auto shape = ck::make_tuple(ck::make_tuple(G, ck::make_tuple(Wo, Ho, Do, N)), + ck::make_tuple(C, X, Y, Z)); + const auto in_strides = + ck::make_tuple(ck::make_tuple(C, + ck::make_tuple(filter_strides[2] * GC, + filter_strides[1] * Wi * GC, + filter_strides[0] * Hi * Wi * GC, + Di * Hi * Wi * GC)), + ck::make_tuple(1, + filter_dilations[2] * GC, + filter_dilations[1] * Wi * GC, + filter_dilations[0] * Hi * Wi * GC)); + const auto in_layout = ck::wrapper::make_layout(shape, in_strides); + + const auto out_strides = ck::make_tuple( + ck::make_tuple( + ZYXC, + ck::make_tuple(ZYXC * G, Wo * ZYXC * G, Ho * Wo * ZYXC * G, Do * Ho * Wo * ZYXC * G)), + ck::make_tuple(1, C, X * C, Y * X * C)); + const auto out_layout = ck::wrapper::make_layout(shape, out_strides); + + const ck::index_t input_size = N * Di * Hi * Wi * GC; + // Global memory buffers + SimpleDeviceMem in_buf(input_size * sizeof(DataType)); + SimpleDeviceMem out_buf(ck::wrapper::size(out_layout) * sizeof(DataType)); + + // User can choose appropriate number of threads and sizes per block + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}), + ck::make_tuple(ck::Number<16>{}, ck::Number<1>{})); + // This example doesn't support padding, user should select tile sizes + // which are divisible by the shape. + const auto tile_shape = ck::make_tuple(ck::Number<32>{}, ck::Number<64>{}); + + // Create buffers for global memory + auto input_tensor_global = ck::wrapper::make_tensor( + static_cast(in_buf.GetDeviceBuffer()), in_layout); + auto output_tensor_global = ck::wrapper::make_tensor( + static_cast(out_buf.GetDeviceBuffer()), out_layout); + + // grid layout (dim1, dim0) + const ck::index_t grid_size_x = ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout), + ck::wrapper::size<1>(tile_shape)); + const ck::index_t grid_size_y = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout), + ck::wrapper::size<0>(tile_shape)); + + const auto kernel = DeviceImageToColumnPad0; + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + kernel, + dim3(grid_size_x, grid_size_y, 1), + dim3(ck::wrapper::size(thread_layout)), + 0, + input_tensor_global, + output_tensor_global, + tile_shape, + thread_layout); + + std::size_t num_btype = G * N * Do * Ho * Wo * ZYXC * 2 * sizeof(DataType); + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << std::endl; +} + +int main(int argc, char* argv[]) +{ + constexpr ck::index_t G = 4; // number of groups + constexpr ck::index_t N = 32; // batch + constexpr ck::index_t C = 64; // input channel (per group) + constexpr ck::index_t Z = 3; // filter D + constexpr ck::index_t Y = 3; // filter H + constexpr ck::index_t X = 3; // filter W + constexpr ck::index_t Di = 9; // input D + constexpr ck::index_t Hi = 9; // input H + constexpr ck::index_t Wi = 7; // input W + constexpr ck::index_t Do = 7; // output D + constexpr ck::index_t Ho = 7; // output H + constexpr ck::index_t Wo = 5; // output W + PerformImageToColumnPad0(G, + N, + Di, + Hi, + Wi, + Do, + Ho, + Wo, + C, + Z, + Y, + X, + {1, 1, 1} /*filter_strides*/, + {1, 1, 1} /*filter_dilations*/); + return 0; +} +// MI100 Perf: 0.255178 ms, 1698.9 GB/s, diff --git a/client_example/25_wrapper/wrapper_optimized_gemm.cpp b/client_example/25_wrapper/wrapper_optimized_gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ddf01de612a61daa4d3d2d3fffb30662656573d7 --- /dev/null +++ b/client_example/25_wrapper/wrapper_optimized_gemm.cpp @@ -0,0 +1,308 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/library/utility/host_tensor.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/operations/gemm.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims) +{ + if constexpr(DoPad) + { + return ck::wrapper::pad(layout, padding_dims); + } + else + { + return layout; + } +} + +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +{ + constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); + constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); + constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); + constexpr auto K1 = GemmTraits::K1; + constexpr auto K0PerBlock = KPerBlock / K1; + const auto K0 = ck::math::integer_divide_ceil(K, K1); + + const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1); + // Create layouts for global memory + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + // Apply padding + auto a_padded_global_layout = + ApplyPadding(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock)); + auto b_padded_global_layout = + ApplyPadding(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock)); + auto c_padded_global_layout = + ApplyPadding(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock)); + // Reshape from M,K to K0,M,K1 + const auto reshaped_dims_idxs = + ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{})); + auto a_padded_unmerged_global_layout = + ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs); + auto b_padded_unmerged_global_layout = + ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs); + // Create tensors for global memory + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_padded_unmerged_global_layout); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_padded_unmerged_global_layout); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_padded_global_layout); + // Create layouts and tensors for lds memory. + constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(K0PerBlock, MPerBlock, K1), + ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(K0PerBlock, NPerBlock, K1), + ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{})); + + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + K0PerBlock]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + K0PerBlock]; + + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); + + const auto block_idxs = ck::make_tuple(ck::wrapper::slice(), + static_cast(blockIdx.x), + static_cast(blockIdx.y), + ck::wrapper::slice()); + using DimAccessOrder = ck::Tuple, ck::Number<0>, ck::Number<2>>; + constexpr ck::index_t vector_dim = 2; + + // Create tile and partition for C global memory. Use specific gemm + // functions to get appropriate layouts. + auto c_global_local_tile = + ck::wrapper::make_local_tile(c_global_tensor, + tile_shape_k0_m_n_k1, + block_idxs, + make_tuple(ck::wrapper::slice(K0PerBlock), + ck::Number<1>{}, + ck::Number<1>{}, + ck::wrapper::slice(K1))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + // Define and clear c vgpr register + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + ck::wrapper::clear(c_vgpr_reg); + // Local partitions for lds memory + auto a_lds_tensor_local_partition = + ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x); + auto b_lds_tensor_local_partition = + ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x); + // Lamda to slice tensor, then create local tile and partition + auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) { + const auto k_slice = + ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock), + ck::wrapper::slice(), + ck::wrapper::slice()); + auto local_tile = ck::wrapper::make_local_tile( + tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection); + return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x); + }; + + auto a_global_local_partition = make_global_partition( + a_global_tensor, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}), + 0); + auto b_global_local_partition = make_global_partition( + b_global_tensor, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}), + 0); + + // (row-major vgpr layout) + auto a_vgpr_tensor = + ck::wrapper::make_register_tensor( + ck::wrapper::make_layout( + shape(a_global_local_partition), + ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) * + ck::wrapper::size<2>(a_global_local_partition), + ck::wrapper::size<2>(a_global_local_partition), + ck::Number<1>{}))); + auto b_vgpr_tensor = + ck::wrapper::make_register_tensor( + ck::wrapper::make_layout( + shape(b_global_local_partition), + ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) * + ck::wrapper::size<2>(a_global_local_partition), + ck::wrapper::size<2>(a_global_local_partition), + ck::Number<1>{}))); + // Copy first values to lds + ck::wrapper::copy(a_global_local_partition, + a_vgpr_tensor); + ck::wrapper::copy(b_global_local_partition, + b_vgpr_tensor); + ck::wrapper::copy(a_vgpr_tensor, + a_lds_tensor_local_partition); + ck::wrapper::copy(b_vgpr_tensor, + b_lds_tensor_local_partition); + // Pipeline loop + const ck::index_t num_loop = + __builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock)); + // Skip if only tile should be processed + if(num_loop > 1) + { + ck::index_t i = 0; + do + { + auto a_global_local_partition_i = make_global_partition( + a_global_tensor, + make_tuple( + ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}), + i + 1); + auto b_global_local_partition_i = make_global_partition( + b_global_tensor, + make_tuple( + ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}), + i + 1); + // Copy data to A vgpr. + ck::wrapper::copy( + a_global_local_partition_i, a_vgpr_tensor); + // Synchronize. + ck::block_sync_lds(); + // Copy data to B vgpr. + ck::wrapper::copy( + b_global_local_partition_i, b_vgpr_tensor); + // Perform gemm. + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + // Synchronize + ck::block_sync_lds(); + // Copy data to A and B lds tiles. + ck::wrapper::copy( + a_vgpr_tensor, a_lds_tensor_local_partition); + ck::wrapper::copy( + b_vgpr_tensor, b_lds_tensor_local_partition); + + ++i; + } while(i < (num_loop - 1)); + } + // Handle tail. + ck::block_sync_lds(); + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + // Store data from C vgpr to C global memory. + ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +} + +template +void PerformGemm(const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape& tile_shape, + const ThreadLayout& thread_layout) +{ + // Global memory buffers + SimpleDeviceMem a_mem(M * K * sizeof(DataType)); + SimpleDeviceMem b_mem(K * N * sizeof(DataType)); + SimpleDeviceMem c_mem(M * N * sizeof(DataType)); + + const ck::index_t grid_size_x = + ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)); + const ck::index_t grid_size_y = + ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); + + const auto kernel = + DeviceGemm; + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + kernel, + dim3(grid_size_x, grid_size_y, 1), + dim3(ck::wrapper::size(thread_layout)), + 0, + a_mem.GetDeviceBuffer(), + b_mem.GetDeviceBuffer(), + c_mem.GetDeviceBuffer(), + M, + N, + K, + tile_shape, + thread_layout); + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << std::endl; +} + +int main(int argc, char* argv[]) +{ + using DataType = ck::half_t; + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{}); + PerformGemm( + 3840, 4096, 4096, tile_shape, thread_layout); + return 0; +} +// MI300X Perf: 0.411552 ms, 313.081 TFlops, 234.403 GB/s, diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index eb793b3cbd6e693fd8cf94ef3489c6ac3eb915c7..753f5e5ae518cd3310ced0f73a4241a45793b3d4 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -48,7 +48,7 @@ else() endif() endif() -find_package(composable_kernel COMPONENTS device_operations) +find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_operations) find_package(hip REQUIRED PATHS /opt/rocm) message(STATUS "Build with HIP ${hip_VERSION}") diff --git a/cmake/ClangTidy.cmake b/cmake/ClangTidy.cmake index 01b348c4583f5d18b202410a694ee6a8e460580a..cf77991a64a7543ac7b97d604ea3b9e72a626c62 100644 --- a/cmake/ClangTidy.cmake +++ b/cmake/ClangTidy.cmake @@ -149,7 +149,7 @@ function(clang_tidy_check TARGET) add_custom_target(${tidy_target} # for some targets clang-tidy not able to get information from .clang-tidy DEPENDS ${SOURCE} - COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml" + COMMAND ${CLANG_TIDY_COMMAND} "-config=\{CheckOptions: \[\{key: bugprone-reserved-identifier.AllowedIdentifiers,value: __HIP_PLATFORM_HCC__\; __HIP_PLATFORM_AMD__\; __HIP_ROCclr__\}\]\}" ${SOURCE} "-export-fixes=${CLANG_TIDY_FIXIT_DIR}/${TARGET}-${tidy_file}.yaml" WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} COMMENT "clang-tidy: Running clang-tidy on target ${SOURCE}..." ) diff --git a/cmake/DoxygenDoc.cmake b/cmake/DoxygenDoc.cmake index 2e3669fcdf48fde87c23a84c2493ce26aecda7af..c91308b5bb6db3cac55628c2bf5123f261cc838d 100644 --- a/cmake/DoxygenDoc.cmake +++ b/cmake/DoxygenDoc.cmake @@ -309,6 +309,8 @@ XML_OUTPUT XML_PROGRAMLISTING ) +set(WARN_AS_ERROR YES) + set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file") function(add_doxygen_doc) diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 66139cc710992df07151b85537c28f483e2ef452..87cb8cdf1180e81c7626a0083d144b3a9c38ba48 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -70,6 +70,7 @@ else() -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt + -Wno-unused-template ) if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang") list(APPEND CMAKE_COMPILER_WARNINGS diff --git a/cmake/getopt.cmake b/cmake/getopt.cmake new file mode 100644 index 0000000000000000000000000000000000000000..dd985ff472287797a69550d785459251e01d62e5 --- /dev/null +++ b/cmake/getopt.cmake @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +add_library(getopt::getopt INTERFACE IMPORTED GLOBAL) + +if(WIN32) + include(FetchContent) + + FetchContent_Declare( + getopt + GIT_REPOSITORY https://github.com/apwojcik/getopt.git + GIT_TAG main + SYSTEM + ) + + set(__build_shared_libs ${BUILD_SHARED_LIBS}) + set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "") + + FetchContent_MakeAvailable(getopt) + + # Restore the old value of BUILD_SHARED_LIBS + set(BUILD_SHARED_LIBS ${__build_shared_libs} CACHE BOOL "Type of libraries to build" FORCE) + + FetchContent_GetProperties(getopt) + + target_link_libraries(getopt::getopt INTERFACE wingetopt) + target_include_directories(getopt::getopt INTERFACE ${getopt_SOURCE_DIR}/src) +endif() \ No newline at end of file diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake deleted file mode 100644 index d6577ac33e71c575b2aa441da30838f58157f0b4..0000000000000000000000000000000000000000 --- a/cmake/googletest.cmake +++ /dev/null @@ -1,50 +0,0 @@ -include(FetchContent) - -set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against") - -if(GOOGLETEST_DIR) - set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override") -endif() - -message(STATUS "Fetching GoogleTest") - -list(APPEND GTEST_CMAKE_CXX_FLAGS - -Wno-undef - -Wno-reserved-identifier - -Wno-global-constructors - -Wno-missing-noreturn - -Wno-disabled-macro-expansion - -Wno-used-but-marked-unused - -Wno-switch-enum - -Wno-zero-as-null-pointer-constant - -Wno-unused-member-function - -Wno-comma - -Wno-old-style-cast - -Wno-deprecated - -Wno-unsafe-buffer-usage -) -message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}") - -FetchContent_Declare( - googletest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG b85864c64758dec007208e56af933fc3f52044ee -) - -# Will be necessary for windows build -# set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) -FetchContent_GetProperties(googletest) -if(NOT googletest_POPULATED) - FetchContent_Populate(googletest) - add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) -endif() - -target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) -target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) -target_compile_options(gmock PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) -target_compile_options(gmock_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) - -set_target_properties(gtest PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(gtest_main PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(gmock PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(gmock_main PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake new file mode 100644 index 0000000000000000000000000000000000000000..0915f534112f441fbf77a93dbcdb58f2bb7f141c --- /dev/null +++ b/cmake/gtest.cmake @@ -0,0 +1,70 @@ +include(FetchContent) + +set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against") + +if(GOOGLETEST_DIR) + set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override") +endif() + +FetchContent_Declare( + GTest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG f8d7d77c06936315286eb55f8de22cd23c188571 +) + +# Suppress ROCMChecks WARNING on GoogleTests +set(ROCM_DISABLE_CHECKS FALSE) +macro(rocm_check_toolchain_var var access value list_file) + if(NOT ROCM_DISABLE_CHECKS) + _rocm_check_toolchain_var("${var}" "${access}" "${value}" "${list_file}") + endif() +endmacro() + +if(WIN32) + set(gtest_force_shared_crt ON CACHE_INTERNAL "") +endif() + +set(BUILD_GMOCK OFF CACHE INTERNAL "") +set(INSTALL_GTEST OFF CACHE INTERNAL "") + +# Store the current value of BUILD_SHARED_LIBS +set(__build_shared_libs ${BUILD_SHARED_LIBS}) +set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "") + +set(ROCM_DISABLE_CHECKS TRUE) +FetchContent_MakeAvailable(GTest) +set(ROCM_DISABLE_CHECKS FALSE) + +# Restore the old value of BUILD_SHARED_LIBS +set(BUILD_SHARED_LIBS ${__build_shared_libs} CACHE BOOL "Type of libraries to build" FORCE) + +set(BUILD_GMOCK OFF CACHE INTERNAL "") +set(INSTALL_GTEST OFF CACHE INTERNAL "") + +set(GTEST_CXX_FLAGS + -Wno-undef + -Wno-reserved-identifier + -Wno-global-constructors + -Wno-missing-noreturn + -Wno-disabled-macro-expansion + -Wno-used-but-marked-unused + -Wno-switch-enum + -Wno-zero-as-null-pointer-constant + -Wno-unused-member-function + -Wno-comma + -Wno-old-style-cast + -Wno-deprecated + -Wno-unsafe-buffer-usage + -Wno-float-equal +) + +if(WIN32) + list(APPEND GTEST_CXX_FLAGS + -Wno-suggest-destructor-override + -Wno-suggest-override + -Wno-nonportable-system-include-path + -Wno-language-extension-token) +endif() + +target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS}) +target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS}) diff --git a/dev-requirements.txt b/dev-requirements.txt index 9e7b9f01e1cae3901bd0f1661f39bc383300ebd6..ca883c19e1f527e094fd8e83bffe962a8de77be5 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,3 @@ -ROCmSoftwarePlatform/rocm-recipes -RadeonOpenCompute/rocm-cmake@04f694df2a8dc9d7e35fa4dee4ba5fa407ec04f8 --build -danmar/cppcheck@2.9 \ No newline at end of file +ROCm/rocm-recipes +ROCm/rocm-cmake@04f694df2a8dc9d7e35fa4dee4ba5fa407ec04f8 --build +danmar/cppcheck@2.9 diff --git a/docs/API_Reference_Guide.rst b/docs/API_Reference_Guide.rst index f21d43c5939ab07d3e9b6cc2c6c361a782b13ff1..22222b0cf0c451ef25208f28cc087bd5d9ec3d0b 100644 --- a/docs/API_Reference_Guide.rst +++ b/docs/API_Reference_Guide.rst @@ -1,11 +1,13 @@ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation -******************* -API Reference Guide -******************* +.. _api-reference: + +******************************************************************** +API reference guide +******************************************************************** -================= -Introduction -================= This document contains details of the APIs for the Composable Kernel (CK) library and introduces some of the key design principles that are used to write new classes that extend CK functionality. @@ -30,7 +32,7 @@ DeviceMem Kernels For Flashattention --------------------------- -The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This sections lists +The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This section lists the classes that are used in the CK GPU implementation of Flashattention. **Gridwise classes** diff --git a/docs/Contributors_Guide.rst b/docs/Contributors_Guide.rst index b2ddff398ce8efcd0c5d8be275d9dd09315d4349..3788ba609c833084bda41819b1109f0c2b3dbc2e 100644 --- a/docs/Contributors_Guide.rst +++ b/docs/Contributors_Guide.rst @@ -1,8 +1,105 @@ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _contributing-to: + +******************************************************************** +Contributor's guide +******************************************************************** + +This chapter explains the rules for contributing to the Composable Kernel project, and how to contribute. + +Getting started +=============== + +#. **Documentation:** Before contributing to the library, familiarize yourself with the + `Composable Kernel User Guide `_. + It provides insight into the core concepts, environment configuration, and steps to obtain or + build the library. You can also find some of this information in the + `README file `_ + on the project's GitHub page. +#. **Additional reading:** The blog post `AMD Composable Kernel library: efficient fused kernels for AI apps with just a few lines of code `_ provides a deeper understanding of the CK library and showcases its performance capabilities. + `_ + from the AMD Community portal. It offers a deeper understanding of the library's objectives and showcases its performance capabilities. +#. **General information:** For broader information about AMD products, consider exploring the + `AMD Developer Central portal `_. + +How to contribute =================== -Contributor's Guide -=================== -Pull-request guidelines -======================= +You can make an impact by reporting issues or proposing code enhancements through pull requests. + +Reporting issues +---------------- + +Use `Github issues `_ +to track public bugs and enhancement requests. + +If you encounter an issue with the library, please check if the problem has already been +reported by searching existing issues on GitHub. If your issue seems unique, please submit a new +issue. All reported issues must include: + +* A comprehensive description of the problem, including: + + * What did you observe? + * Why do you think it is a bug (if it seems like one)? + * What did you expect to happen? What would indicate the resolution of the problem? + * Are there any known workarounds? + +* Your configuration details, including: + + * Which GPU are you using? + * Which OS version are you on? + * Which ROCm version are you using? + * Are you using a Docker image? If so, which one? + +* Steps to reproduce the issue, including: + + * What actions trigger the issue? What are the reproduction steps? + + * If you build the library from scratch, what CMake command did you use? + + * How frequently does this issue happen? Does it reproduce every time? Or is it a sporadic issue? + +Before submitting any issue, ensure you have addressed all relevant questions from the checklist. + +Creating Pull Requests +---------------------- + +You can submit `Pull Requests (PR) on GitHub +`_. + +All contributors are required to develop their changes on a separate branch and then create a +pull request to merge their changes into the `develop` branch, which is the default +development branch in the Composable Kernel project. All external contributors must use their own +forks of the project to develop their changes. + +When submitting a Pull Request you should: + +* Describe the change providing information about the motivation for the change and a general + description of all code modifications. + +* Verify and test the change: + + * Run any relevant existing tests. + * Write new tests if added functionality is not covered by current tests. + +* Ensure your changes align with the coding style defined in the ``.clang-format`` file located in + the project's root directory. We leverage `pre-commit` to run `clang-format` automatically. We + highly recommend contributors utilize this method to maintain consistent code formatting. + Instructions on setting up `pre-commit` can be found in the project's + `README file `_ + +* Link your PR to any related issues: + + * If there is an issue that is resolved by your change, please provide a link to the issue in + the description of your pull request. + +* For larger contributions, structure your change into a sequence of smaller, focused commits, each + addressing a particular aspect or fix. + +Following the above guidelines ensures a seamless review process and faster assistance from our +end. -[TODO] +Thank you for your commitment to enhancing the Composable Kernel project! diff --git a/docs/Supported_Primitives_Guide.rst b/docs/Supported_Primitives_Guide.rst index 3462283d90d62f1a181a1cc26d04a36f9d6f4fff..e24acf56562a7d6f18a525dad6f63f2ef63eb9c0 100644 --- a/docs/Supported_Primitives_Guide.rst +++ b/docs/Supported_Primitives_Guide.rst @@ -1,16 +1,20 @@ -========================== +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _supported-primitives: + +******************************************************************** Supported Primitives Guide -========================== +******************************************************************** -This document contains details of supported primitives in Composable Kernel (CK). In contrast to the -API Reference Guide, the Supported Primitives Guide is an introduction to the math which underpins -the algorithms implemented in CK. +This document contains details of supported primitives in Composable Kernel (CK). In contrast to the API Reference Guide, the Supported Primitives Guide is an introduction to the math which underpins the algorithms implemented in CK. ------------ Softmax ------------ -For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can decompose the +For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` you can decompose the softmax of concatenated :math:`x = [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ]` as, .. math:: @@ -27,7 +31,7 @@ where :math:`f(x^{(j)}) = \exp( x^{(j)} - m(x^{(j)}) )` is of size :math:`B` and :math:`z(x^{(j)}) = f(x_1^{(j)})+ \ldots+ f(x_B^{(j)})` is a scalar. For a matrix :math:`X` composed of :math:`T_r \times T_c` tiles, :math:`X_{ij}`, of size -:math:`B_r \times B_c` we can compute the row-wise softmax as follows. +:math:`B_r \times B_c` you can compute the row-wise softmax as follows. For :math:`j` from :math:`1` to :math:`T_c`, and :math:`i` from :math:`1` to :math:`T_r` calculate, diff --git a/docs/conf.py b/docs/conf.py index 0de590da1a53daba8dc0e379e68abbae250e0667..e441ff1cedb2e373f0f7e6a40ba93344ab498b74 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,23 +4,34 @@ # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html -import subprocess +import re from rocm_docs import ROCmDocs +html_theme_options = {"flavor": "list"} -name = "Composable Kernel" -get_version = r'sed -n -e "s/^rocm_setup_version(.* \([0-9\.]\{1,\}\).*/\1/p" ../CMakeLists.txt' -version = subprocess.getoutput(get_version) -if len(version) > 0: - name = f"{name} {version}" +with open('../CMakeLists.txt', encoding='utf-8') as f: + match = re.search(r'.*set\(version ([0-9.]+)[^0-9.]+', f.read()) + if not match: + raise ValueError("VERSION not found!") + version_number = match[1] +left_nav_title = f"Composable Kernel {version_number} Documentation" + +# for PDF output on Read the Docs +project = "Composable Kernel Documentation" +author = "Advanced Micro Devices, Inc." +copyright = "Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved." +version = version_number +release = version_number external_toc_path = "./sphinx/_toc.yml" -docs_core = ROCmDocs(f"{name} Documentation") -docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/docBin/xml") +docs_core = ROCmDocs(left_nav_title) +docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/xml") docs_core.setup() +external_projects_current_project = "composable_kernel" + mathjax3_config = { 'tex': { 'macros': { diff --git a/docs/dockerhub.rst b/docs/dockerhub.rst index 66ec91096e01b9ebfbe78559b9a168711a7b4206..21121f1b826a1c68e0ae53ac3236e387f1409551 100644 --- a/docs/dockerhub.rst +++ b/docs/dockerhub.rst @@ -1,28 +1,50 @@ -=================== +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _docker-hub: + +******************************************************************** CK Docker Hub -=================== +******************************************************************** -------------------------------------- Why do I need this? -------------------------------------- +=================== + +To make things simpler, and bring Composable Kernel and its dependencies together, +docker images can be found on `Docker Hub `_. Docker images provide a complete image of the OS, the Composable Kernel library, and its dependencies in a single downloadable file. -To make our lives easier and bring Composable Kernel dependencies together, we recommend using -docker images that can be found on `Docker Hub `_. +Refer to `Docker Overview `_ for more information on Docker images and containers. -------------------------------------- -So what is Composable Kernel? -------------------------------------- +Which image is right for me? +============================ + +The image naming includes information related to the docker image. +For example ``ck_ub20.04_rocm6.0`` indicates the following: + +* ``ck`` - made for running Composable Kernel; +* ``ub20.04`` - based on Ubuntu 20.04; +* ``rocm6.0`` - ROCm platform version 6.0. -Composable Kernel (CK) library aims to provide a programming model for writing performance critical -kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, -through general purpose kernel languages, like HIP C++. +Download a docker image suitable for your OS and ROCm release, run or start the docker container, and then resume the tutorial from this point. Use the ``docker pull`` command to download the file:: -To get the CK library:: + docker pull rocm/composable_kernel:ck_ub20.04_rocm6.0 - git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git +What is inside the image? +------------------------- -run a docker container:: +The docker images have everything you need for running CK including: + +* `ROCm `_ +* `CMake `_ +* `Compiler `_ +* `Composable Kernel library `_ + +Running the docker container +============================ + +After downloading the docker image, you can start the container using one of a number of commands. Start with the ``docker run`` command as shown below:: docker run \ -it \ @@ -30,70 +52,50 @@ run a docker container:: --group-add sudo \ -w /root/workspace \ -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ - rocm/composable_kernel:ck_ub20.04_rocm5.6 \ + rocm/composable_kernel:ck_ub20.04_rocm6.0 \ /bin/bash -and build the CK:: +After starting the bash shell, the docker container current folder is `~/workspace`. The library path is ``~/workspace/composable_kernel``. Navigate to the library to begin the tutorial as explained in :ref:`hello-world`: - mkdir build && cd build - # Need to specify target ID, example below is for gfx908 and gfx90a - cmake \ - -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ - -D CMAKE_CXX_FLAGS="-O3" \ - -D CMAKE_BUILD_TYPE=Release \ - -D GPU_TARGETS="gfx908;gfx90a" \ - .. +.. note:: -and:: + If your current folder is different from `${HOME}`, adjust the line ``-v ${HOME}:/root/workspace`` in the ``docker run`` command to fit your folder structure. - make -j examples tests +Stop and restart the docker image +================================= -To run all the test cases including tests and examples run:: +After finishing the tutorial, or just when you have completed your work session, you can close the docker container, or stop the docker container to restart it at another time. Closing the docker container means that it is still in the active state, and can be resumed from where you left it. Stopping the container closes it, and returns the image to its initial state. - make test +Use the ``Ctrl-D`` option to exit the container, while leaving it active, so you can return to the container in its current state to resume the tutorial, or pickup your project where you left off. -We can also run specific examples or tests like:: +To restart the active container use the ``docker exec`` command to specify the container name and options as follows:: - ./bin/example_gemm_xdl_fp16 - ./bin/test_gemm_fp16 + docker exec -it bash -For more details visit `CK github repository `_, -`CK examples `_, -`even more CK examples `_. +Where: -------------------------------------- -And what is inside? -------------------------------------- +* `exec` is the docker command +* `-it` is the interactive option for `exec` +* `` specifies an active container on the system +* `bash` specifies the command to run in the interactive shell -The docker images have everything you need for running CK including: +.. note:: -* `ROCm `_ -* `CMake `_ -* `Compiler `_ + You can use the ``docker container ls`` command to list the active containers on the system. -------------------------------------- -Which image is right for me? -------------------------------------- - -Let's take a look at the image naming, for example ``ck_ub20.04_rocm5.6``. The image specs are: - -* ``ck`` - made for running Composable Kernel; -* ``ub20.04`` - based on Ubuntu 20.04; -* ``rocm5.6`` - ROCm platform version 5.6. +To start a container from the image, use the ``docker start`` command:: -So just pick the right image for your project dependencies and you're all set. + docker start -------------------------------------- -DIY starts here -------------------------------------- +Then use the docker exec command as shown above to start the bash shell. -If you need to customize a docker image or just can't stop tinkering, feel free to adjust the -`Dockerfile `_ -for your needs. +Use the ``docker stop`` command to stop the container and restore the image to its initial state:: -------------------------------------- -License -------------------------------------- + docker stop + +Editing the docker image +======================= -CK is released under the MIT `license `_. +If you want to customize the docker image, edit the +`Dockerfile `_ +from the GitHub repository to suit your needs. diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile index 1084f94c81bbef88a3373d6e9e5d24b89e7b5353..fac9e138e1bc59a8ea37afceebbdedb9ff03e084 100644 --- a/docs/doxygen/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -58,7 +58,7 @@ PROJECT_LOGO = # entered, it will be relative to the location where doxygen was started. If # left blank the current directory will be used. -OUTPUT_DIRECTORY = docBin +OUTPUT_DIRECTORY = . # If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub- # directories (in 2 levels) under the output directory of each output format and @@ -778,7 +778,9 @@ WARN_LOGFILE = INPUT = ../../include/ck/tensor_operation/gpu/grid \ ../../include/ck/tensor_operation/gpu/block \ ../../include/ck/tensor_operation/gpu/thread \ - ../../library/include/ck/library/utility + ../../library/include/ck/library/utility \ + ../../include/ck/wrapper + # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses diff --git a/docs/index.rst b/docs/index.rst index 51c0c862ae3e3e8b314aa37bd63b6d762b895607..8ae4ce3a22fa9e69e49f1ede82447fa54de4be26 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,55 +1,39 @@ -============================ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _composable-kernel: + +******************************************************************** Composable Kernel User Guide -============================ +******************************************************************** + +The Composable Kernel (CK) library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages like HIP C++. This document contains instructions for installing, using, and contributing to the Composable Kernel project. To learn more see :ref:`what-is-ck`. ------------- -Introduction ------------- +The CK documentation is structured as follows: -This document contains instructions for installing, using, and contributing to Composable Kernel (CK). +.. card:: Conceptual ------------ -Methodology ------------ + * :ref:`what-is-ck` -Composable Kernel (CK) library aims to provide a programming model for writing performance critical -kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, -through general purpose kernel languages, like HIP C++. +.. card:: Installation -CK utilizes two concepts to achieve performance portability and code maintainability: + * :ref:`docker-hub` -* A tile-based programming model -* Algorithm complexity reduction for complex ML operators, using innovative technique we call - "Tensor Coordinate Transformation". +.. card:: Tutorial -.. image:: data/ck_component.png - :alt: CK Components + * :ref:`hello-world` --------------- -Code Structure --------------- +.. card:: API reference -Current CK library are structured into 4 layers: + * :ref:`supported-primitives` + * :ref:`api-reference` + * :ref:`wrapper` -* "Templated Tile Operators" layer -* "Templated Kernel and Invoker" layer -* "Instantiated Kernel and Invoker" layer -* "Client API" layer +.. card:: Contributing to CK -.. image:: data/ck_layer.png - :alt: CK Layers - -Documentation Roadmap -^^^^^^^^^^^^^^^^^^^^^ -The following is a list of CK documents in the suggested reading order: + * :ref:`contributing-to` -.. toctree:: - :maxdepth: 5 - :caption: Contents: - :numbered: +To contribute to the documentation refer to `Contributing to ROCm `_. - tutorial_hello_world - dockerhub - Supported_Primitives_Guide - API_Reference_Guide - Contributors_Guide +You can find licensing information on the `Licensing `_ page. diff --git a/docs/license.md b/docs/license.md new file mode 100644 index 0000000000000000000000000000000000000000..43e471da0e49af1156dc998033d224d9eb172ab3 --- /dev/null +++ b/docs/license.md @@ -0,0 +1,2 @@ +```{include} ../LICENSE.md +``` diff --git a/docs/license.rst b/docs/license.rst deleted file mode 100644 index ddb544496e41b92dcaadb6ff816946e46ca16da8..0000000000000000000000000000000000000000 --- a/docs/license.rst +++ /dev/null @@ -1,6 +0,0 @@ -======= -License -======= - -.. include:: ../LICENSE - :literal: diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index 83dd1e7b1a5d759143f2db13a1e296f5e66da96d..5780674624ad9e2cdabc3c11906e0cbb8577c1e5 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -1,10 +1,21 @@ -# Anywhere {branch} is used, the branch name will be substituted. -# These comments will also be removed. defaults: numbered: False - maxdepth: 6 root: index subtrees: - - caption: About - entries: - - file: license +- entries: + - file: what-is-ck.rst + title: What is Composable Kernel? + - file: dockerhub.rst + title: Docker Hub + - file: tutorial_hello_world.rst + title: Hello World Tutorial + - file: Supported_Primitives_Guide.rst + title: Supported Primitives + - file: API_Reference_Guide.rst + title: API Reference + - file: wrapper.rst + title: Wrapper + - file: Contributors_Guide.rst + title: Contributing to CK + - file: license.md + title: License diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index a12a00a2b21ec12d8d6f41d0838309350d2eefd1..ae2cbe44ab6fcb85b47403c41e1cc206e13dcd94 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core>=0.20.0 -sphinxcontrib-bibtex==2.5.0 +rocm-docs-core==0.34.2 +sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index d5f67eeb585b3b457ee49bf5b1b46468f76e742a..84f232fa2dbf90be368fcdd4f3e5bb46392861ea 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -16,7 +16,7 @@ beautifulsoup4==4.11.2 # via pydata-sphinx-theme breathe==4.34.0 # via rocm-docs-core -certifi==2022.12.7 +certifi==2023.7.22 # via requests cffi==1.15.1 # via @@ -26,7 +26,7 @@ charset-normalizer==3.1.0 # via requests click==8.1.3 # via sphinx-external-toc -cryptography==40.0.2 +cryptography==41.0.6 # via pyjwt deprecated==1.2.13 # via pygithub @@ -42,12 +42,18 @@ fastjsonschema==2.18.0 # via rocm-docs-core gitdb==4.0.10 # via gitpython -gitpython==3.1.31 +gitpython==3.1.37 # via rocm-docs-core idna==3.4 # via requests imagesize==1.4.1 # via sphinx +importlib-metadata==6.8.0 + # via + # sphinx + # sphinxcontrib-bibtex +importlib-resources==6.1.0 + # via rocm-docs-core jinja2==3.1.2 # via # myst-parser @@ -82,28 +88,32 @@ pydata-sphinx-theme==0.13.3 # via # rocm-docs-core # sphinx-book-theme -pygithub==1.58.2 +pygithub==1.58.1 # via rocm-docs-core -pygments==2.14.0 +pygments==2.15.0 # via # accessible-pygments # pydata-sphinx-theme # sphinx pyjwt[crypto]==2.6.0 - # via pygithub + # via + # pygithub + # pyjwt pynacl==1.5.0 # via pygithub +pytz==2023.3.post1 + # via babel pyyaml==6.0 # via # myst-parser # pybtex # rocm-docs-core # sphinx-external-toc -requests==2.28.2 +requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core>=0.20.0 +rocm-docs-core==0.34.2 # via -r requirements.in six==1.16.0 # via @@ -131,7 +141,7 @@ sphinx-book-theme==1.0.1 # via rocm-docs-core sphinx-copybutton==0.5.1 # via rocm-docs-core -sphinx-design==0.3.0 +sphinx-design==0.4.1 # via rocm-docs-core sphinx-external-toc==0.3.1 # via rocm-docs-core @@ -139,7 +149,7 @@ sphinx-notfound-page==0.8.3 # via rocm-docs-core sphinxcontrib-applehelp==1.0.4 # via sphinx -sphinxcontrib-bibtex==2.5.0 +sphinxcontrib-bibtex==2.6.2 # via -r requirements.in sphinxcontrib-devhelp==1.0.2 # via sphinx @@ -153,7 +163,11 @@ sphinxcontrib-serializinghtml==1.1.5 # via sphinx typing-extensions==4.5.0 # via pydata-sphinx-theme -urllib3==1.26.15 +urllib3==1.26.18 # via requests wrapt==1.15.0 # via deprecated +zipp==3.17.0 + # via + # importlib-metadata + # importlib-resources diff --git a/docs/tutorial_hello_world.rst b/docs/tutorial_hello_world.rst index bfb197e085a2a986ae50aea6b6eeebd8d203eab1..c31460785bd127b92a4b2f53551541b041919102 100644 --- a/docs/tutorial_hello_world.rst +++ b/docs/tutorial_hello_world.rst @@ -1,52 +1,44 @@ -=============== -CK Hello world -=============== +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation -------------------------------------- -Motivation -------------------------------------- +.. _hello-world: -This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who -would like to optimize their pipelines and squeeze every performance drop by adding Composable -Kernel (CK) library to their projects. We would like to make the CK library approachable so -the tutorial is not based on the latest release and doesn't have all the bleeding edge features, -but it will be reproducible now and forever. +******************************************************************** +Hello World Tutorial +******************************************************************** -During this tutorial we will have an introduction to the CK library, we will build it and run some -examples and tests, so to say we will run a "Hello world" example. In future tutorials we will go -in depth and breadth and get familiar with other tools and ways to integrate CK into your project. +This tutorial is for engineers dealing with artificial intelligence and machine learning who +would like to optimize pipelines and improve performance using the Composable +Kernel (CK) library. This tutorial provides an introduction to the CK library. You will build the library and run some examples using a "Hello World" example. -------------------------------------- Description -------------------------------------- +=========== -Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and -efficient workflows is still challenging. CK is one of the tools to make AI heavy lifting as fast -and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create -new ones. The library has components required for majority of modern neural networks architectures -including matrix multiplication, convolution, contraction, reduction, attention modules, variety of -activation functions, fused operators and many more. +Modern AI technology solves more and more problems in a variety of fields, but crafting fast and +efficient workflows is still challenging. CK can make the AI workflow fast +and efficient. CK is a collection of optimized AI operator kernels with tools to create +new kernels. The library has components required for modern neural network architectures +including matrix multiplication, convolution, contraction, reduction, attention modules, a variety of activation functions, and fused operators. -So how do we (almost) reach the speed of light? CK acceleration abilities are based on: +CK library acceleration features are based on: -* Layered structure. -* Tile-based computation model. -* Tensor coordinate transformation. -* Hardware acceleration use. -* Support of low precision data types including fp16, bf16, int8 and int4. +* Layered structure +* Tile-based computation model +* Tensor coordinate transformation +* Hardware acceleration use +* Support of low precision data types including fp16, bf16, int8 and int4 -If you are excited and need more technical details and benchmarking results - read this awesome +If you need more technical details and benchmarking results read the following `blog post `_. -For more details visit our `github repository `_. +To download the library visit the `composable_kernel repository `_. -------------------------------------- Hardware targets -------------------------------------- +================ -CK library fully supports `gfx908` and `gfx90a` GPU architectures and only some operators are -supported for `gfx1030`. Let's check the hardware you have at hand and decide on the target -GPU architecture. +CK library fully supports `gfx908` and `gfx90a` GPU architectures, while only some operators are +supported for `gfx1030` devices. Check your hardware to determine the target GPU architecture. ========== ========= GPU Target AMD GPU @@ -59,47 +51,24 @@ gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6 There are also `cloud options `_ you can find if you don't have an AMD GPU at hand. -------------------------------------- Build the library -------------------------------------- +================= -First let's clone the library and rebase to the tested version:: +This tutorial is based on the use of docker images as explained in :ref:`docker-hub`. Download a docker image suitable for your OS and ROCm release, run or start the docker container, and then resume the tutorial from this point. - git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git - cd composable_kernel/ - git checkout tutorial_hello_world - -To make our lives easier we prepared -`docker images `_ with all the necessary -dependencies. Pick the right image and create a container. In this tutorial we use -``rocm/composable_kernel:ck_ub20.04_rocm5.6`` image, it is based on Ubuntu 20.04 and -ROCm v5.6. - -If your current folder is ``${HOME}``, start the docker container with:: - - docker run \ - -it \ - --privileged \ - --group-add sudo \ - -w /root/workspace \ - -v ${HOME}:/root/workspace \ - rocm/composable_kernel:ck_ub20.04_rocm5.6 \ - /bin/bash +.. note:: -If your current folder is different from ``${HOME}``, adjust the line ``-v ${HOME}:/root/workspace`` -to fit your folder structure. + You can also `install ROCm `_ on your system, clone the `Composable Kernel repository `_ on GitHub, and use that to build and run the examples using the commands described below. -Inside the docker container current folder is ``~/workspace``, library path is -``~/workspace/composable_kernel``, navigate to the library:: +Both the docker container and GitHub repository include the Composable Kernel library. Navigate to the library:: cd composable_kernel/ -Create and go to the ``build`` directory:: +Create and change to a ``build`` directory:: mkdir build && cd build -In the previous section we talked about target GPU architecture. Once you decide which one is right -for you, run CMake using the right ``GPU_TARGETS`` flag:: +The previous section discussed supported GPU architecture. Once you decide which hardware targets are needed, run CMake using the ``GPU_TARGETS`` flag:: cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ @@ -109,26 +78,25 @@ for you, run CMake using the right ``GPU_TARGETS`` flag:: -D BUILD_DEV=OFF \ -D GPU_TARGETS="gfx908;gfx90a;gfx1030" .. -If everything went well the CMake run will end up with:: +If everything goes well the CMake command will return:: -- Configuring done -- Generating done -- Build files have been written to: "/root/workspace/composable_kernel/build" -Finally, we can build examples and tests:: +Finally, you can build examples and tests:: make -j examples tests -If everything is smooth, you'll see:: +When complete you should see:: Scanning dependencies of target tests [100%] Built target tests ---------------------------- Run examples and tests ---------------------------- +====================== -Examples are listed as test cases as well, so we can run all examples and tests with:: +Examples are listed as test cases as well, so you can run all examples and tests with:: ctest @@ -136,38 +104,32 @@ You can check the list of all tests by running:: ctest -N -We can also run them separately, here is a separate example execution:: +You can also run examples separately as shown in the following example execution:: ./bin/example_gemm_xdl_fp16 1 1 1 -The arguments ``1 1 1`` mean that we want to run this example in the mode: verify results with CPU, -initialize matrices with integers and benchmark the kernel execution. You can play around with -these parameters and see how output and execution results change. +The arguments ``1 1 1`` mean that you want to run this example in the mode: verify results with CPU, initialize matrices with integers, and benchmark the kernel execution. You can play around with these parameters and see how output and execution results change. -If everything goes well and you have a device based on `gfx908` or `gfx90a` architecture you should see -something like:: +If you have a device based on `gfx908` or `gfx90a` architecture, and if the example runs as expected, you should see something like:: a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} - b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} + b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} - launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} - Warm up 1 time - Start running 10 times... - Perf: 1.10017 ms, 117.117 TFlops, 87.6854 GB/s, DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 + Perf: 1.08153 ms, 119.136 TFlops, 89.1972 GB/s, DeviceGemm_Xdl_CShuffle LoopScheduler: Interwave, PipelineVersion: v1 -Meanwhile, running it on a `gfx1030` device should result in:: +However, running it on a `gfx1030` device should result in the following:: a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem -But don't panic, some of the operators are supported on `gfx1030` architecture, so you can run a +Don't worry, some operators are supported on `gfx1030` architecture, so you can run a separate example like:: ./bin/example_gemm_dl_fp16 1 1 1 -and it should result in something nice similar to:: +and it should return something like:: a_m_k: dim 2, lengths {3840, 4096}, strides {1, 4096} b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} @@ -182,12 +144,9 @@ and it should result in something nice similar to:: .. note:: - There was a new CMake flag ``DL_KERNELS`` added in the latest versions of CK. If you use one of - the newest versions of the library and do not see the above results when running - ``example_gemm_dl_fp16``, it might be necessary to add ``-D DL_KERNELS=ON`` to your CMake command - in order to build the operators supported on the `gfx1030` architecture. + A new CMake flag ``DL_KERNELS`` has been added to the latest versions of CK. If you do not see the above results when running ``example_gemm_dl_fp16``, you might need to add ``-D DL_KERNELS=ON`` to your CMake command to build the operators supported on the `gfx1030` architecture. -We can also run a separate test:: +You can also run a separate test:: ctest -R test_gemm_fp16 @@ -198,13 +157,9 @@ If everything goes well you should see something like:: 100% tests passed, 0 tests failed out of 1 ------------ Summary ------------ +======= -In this tutorial we took the first look at the Composable Kernel library, built it on your system -and ran some examples and tests. Stay tuned, in the next tutorial we will run kernels with different -configs to find out the best one for your hardware and task. +In this tutorial you took the first look at the Composable Kernel library, built it on your system and ran some examples and tests. In the next tutorial you will run kernels with different configurations to find out the best one for your hardware and task. -P.S.: Don't forget to switch off the cloud instance if you have launched one, you can find better -ways to spend your money for sure! +P.S.: If you are running on a cloud instance, don't forget to switch off the cloud instance. diff --git a/docs/what-is-ck.rst b/docs/what-is-ck.rst new file mode 100644 index 0000000000000000000000000000000000000000..f0b51c48f81a8ca43595cf7dda2c1b91a083733a --- /dev/null +++ b/docs/what-is-ck.rst @@ -0,0 +1,41 @@ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _what-is-ck: + +******************************************************************** +What is the Composable Kernel library +******************************************************************** + + +Methodology +=========== + +The Composable Kernel (CK) library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages like HIP C++. + +CK utilizes two concepts to achieve performance portability and code maintainability: + +* A tile-based programming model +* Algorithm complexity reduction for complex ML operators using an innovative technique called + "Tensor Coordinate Transformation". + +.. image:: data/ck_component.png + :alt: CK Components + + +Code Structure +============== + +The CK library is structured into 4 layers: + +* "Templated Tile Operators" layer +* "Templated Kernel and Invoker" layer +* "Instantiated Kernel and Invoker" layer +* "Client API" layer + +It also includes a simple wrapper component used to perform tensor transform operations more easily and with fewer lines of code. + +.. image:: data/ck_layer.png + :alt: CK Layers + \ No newline at end of file diff --git a/docs/wrapper.rst b/docs/wrapper.rst new file mode 100644 index 0000000000000000000000000000000000000000..39e2fd0bbd3942f1a9433c12a8e1ed2cb054272c --- /dev/null +++ b/docs/wrapper.rst @@ -0,0 +1,94 @@ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _wrapper: + +******************************************************************** +Wrapper +******************************************************************** + +------------------------------------- +Description +------------------------------------- + + +The CK library provides a lightweight wrapper for more complex operations implemented in +the library. + +Example: + +.. code-block:: c + + const auto shape_4x2x4 = ck::make_tuple(4, ck::make_tuple(2, 4)); + const auto strides_s2x1x8 = ck::make_tuple(2, ck::make_tuple(1, 8)); + const auto layout = ck::wrapper::make_layout(shape_4x2x4, strides_s2x1x8); + + std::array data; + auto tensor = ck::wrapper::make_tensor(&data[0], layout); + + for(ck::index_t w = 0; w < size(tensor); w++) { + tensor(w) = w; + } + + // slice() == slice(0, -1) (whole dimension) + auto tensor_slice = tensor(ck::wrapper::slice(1, 3), ck::make_tuple(ck::wrapper::slice(), ck::wrapper::slice())); + std::cout << "dims:2,(2,4) strides:2,(1,8)" << std::endl; + for(ck::index_t h = 0; h < ck::wrapper::size<0>(tensor_slice); h++) + { + for(ck::index_t w = 0; w < ck::wrapper::size<1>(tensor_slice); w++) + { + std::cout << tensor_slice(h, w) << " "; + } + std::cout << std::endl; + } + +Output:: + + dims:2,(2,4) strides:2,(1,8) + 1 5 9 13 17 21 25 29 + 2 6 10 14 18 22 26 30 + + +Tutorials: + +* `GEMM tutorial `_ + +Advanced examples: + +* `Image to column `_ +* `Basic gemm `_ +* `Optimized gemm `_ + +------------------------------------- +Layout +------------------------------------- + +.. doxygenstruct:: ck::wrapper::Layout + +------------------------------------- +Layout helpers +------------------------------------- + +.. doxygenfile:: layout_utils.hpp + +------------------------------------- +Tensor +------------------------------------- + +.. doxygenstruct:: ck::wrapper::Tensor + +------------------------------------- +Tensor helpers +------------------------------------- + +.. doxygenfile:: tensor_utils.hpp + +.. doxygenfile:: tensor_partition.hpp + +------------------------------------- +Operations +------------------------------------- + +.. doxygenfile:: copy.hpp +.. doxygenfile:: gemm.hpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index a5933262a588df5f507b8b006460e6f68a76025e..5b71cd15480062368e14e08b0c635b61cccc04f9 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -1,71 +1,76 @@ -if(DL_KERNELS) - add_custom_target(example_gemm_dl) - - add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_fp32) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_fp16) - add_example_executable(example_gemm_dl_dpp8_fp16 gemm_dl_dpp8_fp16.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_dpp8_fp16) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_int8) - endif() - - if(USE_BITINT_EXTENSION_INT4) +add_custom_target(example_gemm_dl) + +add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) +add_example_dependencies(example_gemm_dl example_gemm_dl_fp32) + +add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) +add_example_dependencies(example_gemm_dl example_gemm_dl_fp16) + +add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp) + +add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) +add_example_dependencies(example_gemm_dl example_gemm_dl_int8) +if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_int4) - endif(USE_BITINT_EXTENSION_INT4) -endif() + add_example_dependencies(example_gemm_dl example_gemm_dl_int4) +endif(USE_BITINT_EXTENSION_INT4) add_custom_target(example_gemm_xdl) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) - add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_fp16) - add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) - add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) - - if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") +add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16) + +add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2) + +add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) + +add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) +if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") add_custom_target(example_gemm_wmma) add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) - add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) - endif() - + add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_bf16) -endif() +add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16) -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_int8) -endif() +add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn) + +add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_int8) if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_int4) + add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed - add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) -endif() +# FIXME: re-enable this example as test when SWDEV-335738 is fixed +add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) - if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") - add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_f8) - endif() -endif() +add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) + +add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) + +list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32) + + add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16) + set(target 1) + endif() +endforeach() -add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) -add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8) +add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) diff --git a/example/01_gemm/gemm_dl_dpp8_fp16.cpp b/example/01_gemm/gemm_dl_dpp8_fp16.cpp deleted file mode 100644 index ea0ba390767afd5bb52f2d2b6ce58bb0d6e7f16e..0000000000000000000000000000000000000000 --- a/example/01_gemm/gemm_dl_dpp8_fp16.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" - -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; - -using ALayout = Col; -using BLayout = Row; -using CLayout = Row; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CElementOp = PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDlDpp8 -// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| -// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| -// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// clang-format on - -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - -#include "run_gemm_example.inc" - -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dl_int4.cpp b/example/01_gemm/gemm_dl_int4.cpp index e55ae140130c0779c1e81e0bdc84b6724c6bac86..43c0cfe2e014245da149cd6470e1807c7222e5fc 100644 --- a/example/01_gemm/gemm_dl_int4.cpp +++ b/example/01_gemm/gemm_dl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -43,3 +41,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +#endif \ No newline at end of file diff --git a/example/01_gemm/gemm_dpp_fp16.cpp b/example/01_gemm/gemm_dpp_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7a9e3f6186cc1525c8d882e9a961ae6ec06f95b2 --- /dev/null +++ b/example/01_gemm/gemm_dpp_fp16.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CDataType = ck::half_t; + +using F16 = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDpp +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 64, 64, 64, 8, 2, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 5, 1>; +// // clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_bf16_rtn.cpp b/example/01_gemm/gemm_xdl_bf16_rtn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc14dcb8eb48f37b7b3597a02407ef440e17ab37 --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16_rtn.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/utility/type_convert.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = ck::tensor_operation::element_wise::ConvertBF16RTN; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 54fbd9cdd4983c1a810f3c8a1ac1521b363af690..2338cdc9c12a47542731d8bba47bb5be24370bc8 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -9,13 +9,13 @@ using ADataType = ck::half_t; using BDataType = ck::half_t; using AccDataType = float; -using CShuffleDataType = float; +using CShuffleDataType = ck::half_t; using CDataType = ck::half_t; using F16 = ck::half_t; using ALayout = Row; -using BLayout = Col; +using BLayout = Row; using CLayout = Row; using AElementOp = PassThrough; @@ -30,7 +30,7 @@ using DeviceGemmInstance0 = ck::tensor_operation::device::DeviceGemmXdl // ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| // ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1>; // // clang-format on // clang-format off @@ -39,7 +39,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 2, S<1, 16, 1, 16>, 8, ck::LoopScheduler::Interwave, ck::PipelineVersion::v1>; // clang-format on using DeviceGemmInstance = DeviceGemmInstance1; diff --git a/example/01_gemm/gemm_xdl_fp16_f8.cpp b/example/01_gemm/gemm_xdl_fp16_fp8.cpp similarity index 100% rename from example/01_gemm/gemm_xdl_fp16_f8.cpp rename to example/01_gemm/gemm_xdl_fp16_fp8.cpp diff --git a/example/01_gemm/gemm_xdl_fp16_v2.cpp b/example/01_gemm/gemm_xdl_fp16_v2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eba0ea9d110ef2339484daa4a2143e6a6f6cbeb0 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_v2.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using F16 = ck::half_t; +using F32 = float; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV2< + ALayout, BLayout, CLayout, + F16, F16, F16, F32, F16, + PassThrough, PassThrough, PassThrough, GemmDefault, + 2, 256, + 256, 256, + 32, 8, 4, + 32, 32, + 4, 4, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 4, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::LoopScheduler::Default, ck::PipelineVersion::v1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_f8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp similarity index 97% rename from example/01_gemm/gemm_xdl_f8.cpp rename to example/01_gemm/gemm_xdl_fp8.cpp index 10159267776f6f33469affc618192e29d62763ac..2d4df3fc130369dcbb71ca5eb004da3f57df7662 100644 --- a/example/01_gemm/gemm_xdl_f8.cpp +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -7,9 +7,9 @@ using ADataType = ck::f8_t; using BDataType = ck::f8_t; -using CDataType = ck::f8_t; +using CDataType = ck::half_t; using AccDataType = float; -using CShuffleDataType = ck::f8_t; +using CShuffleDataType = float; using ALayout = Row; using BLayout = Col; @@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_xdl_fp8_bf8.cpp b/example/01_gemm/gemm_xdl_fp8_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b54df8ff3d350c1ef33102052589f589edbe59d5 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp8_bf8.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::bf8_t; +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto LoopSched = ck::make_default_loop_scheduler(); +static constexpr auto PipelineVer = ck::PipelineVersion::v1; +using ComputeTypeA = ck::f8_t; +using ComputeTypeB = ck::bf8_t; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_int4.cpp b/example/01_gemm/gemm_xdl_int4.cpp index f6238c7aa5040d8e409a139fe8c144835282cdc9..fb4f383fae1e96a555a2c83f0b842aaec588068c 100644 --- a/example/01_gemm/gemm_xdl_int4.cpp +++ b/example/01_gemm/gemm_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -44,3 +42,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host:: #include "run_gemm_example.inc" int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +#endif \ No newline at end of file diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d29cb74cd6419956f45416e1f46b6fff76809991 --- /dev/null +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "common.hpp" + +#define USING_DIRECT_LOADS 1 +#if USING_DIRECT_LOADS +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp" +#else +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#endif + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +#if USING_DIRECT_LOADS +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_LdsDirectLoad +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>; +// clang-format on +#else +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>; +// clang-format on +#endif +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e99249389ef8d6fc972c0a548cf50810cf46e187 --- /dev/null +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "common.hpp" + +#define USING_DIRECT_LOADS 1 +#if USING_DIRECT_LOADS +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp" +#else +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#endif + +using F32 = float; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +#if USING_DIRECT_LOADS +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_LdsDirectLoad +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>; +// clang-format on +#else +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>; +// clang-format on +#endif +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index 9989e45e0c275189c592414a2fc8e2a28b2fc511..d82c42d5a9d83f5c95e19f7b2f23dca7c3d2c5e9 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1,10 +1,12 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102) list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list1 AND target EQUAL 0) add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) + add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp) +endif() +if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") set(target 1) endif() endforeach() @@ -16,4 +18,3 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() -endif() diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f23ad26522f4a21c18db74fe5dd4468d3ff50e7 --- /dev/null +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +struct AlphaBetaAdd +{ + AlphaBetaAdd(int alpha, int beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + std::int8_t& e, const std::int32_t& c, const std::int8_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + int alpha_; + int beta_; +}; + +template +using S = ck::Sequence; + +using I8 = std::int8_t; +using I32 = std::int32_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using DDataType = I8; +using EDataType = I8; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 32, + 16, + 16, + 4, + 16, + 16, + 16, + 1, + 1, + S<2, 16, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + 1, + S<4, 1, 8>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 16, + 2, + 1, + 1, + 1, + S<1, 16, 1, 2>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + int alpha = 1; + int beta = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt index a247a052cba40410e8879a57a0e3c35a57b5645f..2f5cba924dfd044a1ff0717d75a8b6e7efc3d16b 100644 --- a/example/03_gemm_bias_relu/CMakeLists.txt +++ b/example/03_gemm_bias_relu/CMakeLists.txt @@ -1,4 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() -endif() diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index 15ec62c89fcd39e24e7bb1ae5f8e7d366f2138d6..33ac1e7e7729717ab9c9542ef58945bae7bfab18 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,28 +1,36 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_gemm_add_add_fastgelu_xdl) - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_gemm_add_add_fastgelu_xdl) + add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) + + add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) + + add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) + + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) + endif(USE_BITINT_EXTENSION_INT4) + + add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) + set(target 1) endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) - endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) - endif() - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) +endforeach() + +set(gpu_list "") + +list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32 gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32) + set(target 1) endif() - set(target 1) - endif() endforeach() \ No newline at end of file diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp index f206bbeb411bb0de0144015d70a892b5587d6f0e..1d0b0f7861858f44dbd692e52f206e157a1ebe6a 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -58,3 +56,4 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_add_add_fastgelu_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_add_add_fastgelu_example(argc, argv); } diff --git a/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc b/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc index cb3147bcd718c115e47163af58e813bd2cf39bfe..cb0271c81f1d3feac72004da7e37a288fc97d463 100644 --- a/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc +++ b/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc @@ -105,7 +105,8 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC if(!device_op.IsSupportedArgument(argument)) { - throw std::runtime_error("wrong! this device_op instance does not support this problem"); + std::cerr << device_op.GetTypeString() << " does not support this problem" << std::endl; + return true; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 1af1e6c8582f6f0f4ccf3902dc1dd0c1d5bd02d5..f9903bfe0390725997e166be78f8b0650eaf80fb 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -2,34 +2,16 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) - endif() - # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed - if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) - endif() set(target 1) endif() endforeach() -if(DL_KERNELS) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) - endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) - endif() -endif() +add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) +add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) +add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp index 74cf91d16017bd24096dac6348e5a8815856a51a..b6bb03e1e58d04f1804387be42ee58e5fe7800a8 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp @@ -3,7 +3,7 @@ #include "convnd_fwd_common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" @@ -27,7 +27,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index f6d69bafd48a4b8bff968c2fb8600793b1310294..064a971478f2f8445de318122f4a346834f70a0c 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -3,7 +3,7 @@ #include "convnd_fwd_common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" @@ -27,7 +27,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp index 6c3171f6157ac4a83a488af12089a97accfd2c7c..36517e569dd5a0574190a01c317fec10a718badf 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -3,7 +3,7 @@ #include "convnd_fwd_common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" @@ -27,7 +27,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp index 9977a496d229876c083e0afe99e17ff0fa5fca05..659283c3f05b8a23c639eb1420a9f67f54d788d6 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp @@ -3,7 +3,7 @@ #include "convnd_fwd_common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" @@ -27,7 +27,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp index bf084b3cc0b6ceab326b2b21e963e0ce2d4bdcad..0180e6e71890d8db2eb2f6ac6423003ed95bbb33 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -3,7 +3,7 @@ #include "convnd_fwd_common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" @@ -27,7 +27,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt index e7d941ae6b3dddee9f8dfa357cf3426045d0e828..222a3b7c0be18e426860d51ec44fc758c1f22d6b 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt @@ -1,28 +1,25 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_fwd_reduce_xdl) - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) - endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) - endif() - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() \ No newline at end of file + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_fwd_reduce_xdl) + + add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) + + add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) + + add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) + + add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) + + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) + endif(USE_BITINT_EXTENSION_INT4) + set(target 1) + endif() +endforeach() diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp index bf7127502faac94e6858d37c389f316b7ead35a6..ce8ff0b49a9b1d72ab6fca2914db680ec9497c8c 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #define BUILD_INT4_EXAMPLE @@ -24,3 +22,4 @@ using RsDataType = ck::Tuple; #include "run_convnd_fwd_max_example.inc" int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); } +#endif diff --git a/example/12_reduce/README.md b/example/12_reduce/README.md index 76d28527bb8aa521611e54d172b427aabf9fb998..bcffa684c8f2c5c34efbf4710656b0124dfa8a42 100644 --- a/example/12_reduce/README.md +++ b/example/12_reduce/README.md @@ -2,7 +2,7 @@ ## Run ```example_reduce_blockwise``` ```bash -# -D : input 3d/4d/5d tensor lengths +# -D : input 3D/4D/5D tensor lengths # -R : reduce dimension ids # -v : verification (0=no, 1=yes) #arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64, 7: int4) @@ -22,7 +22,7 @@ Perf: 0.238063 ms, 264.285 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr ## Run ```example_reduce_multiblock_atomic_add``` ```bash -# -D : input 3d/4d/5d tensor lengths +# -D : input 3D/4D/5D tensor lengths # -R : reduce dimension ids # -v : verification (0=no, 1=yes) #arg1: data type (0: fp32, 1: fp64) diff --git a/example/13_pool2d_fwd/CMakeLists.txt b/example/13_pool2d_fwd/CMakeLists.txt index d0f356757b3845cb1a72708f0cd5c367eb1042fc..e2a923cded8fe2ceecabfaff5a3201817ae3700a 100644 --- a/example/13_pool2d_fwd/CMakeLists.txt +++ b/example/13_pool2d_fwd/CMakeLists.txt @@ -1,6 +1,2 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) -endif() +add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) +add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index 3b3ad80dd826fef38f8d4ed36d504538ae3e263a..9793e8b8a096988caac089a63651230c50793e65 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,9 +1,5 @@ -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) # dlops -if(DL_KERNELS) - add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) -endif() - +add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) # xdlops list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) @@ -14,4 +10,3 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() -endif() \ No newline at end of file diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index dca60b0e739d556ee9c6cb6d463884c8b308866a..84040fcf5cb939688f7a7b03095321b15548c400 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -1,26 +1,32 @@ add_custom_target(example_grouped_gemm_xdl) -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32) -endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) - add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) - add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) - add_dependencies(example_grouped_gemm_xdl - example_grouped_gemm_xdl_fp16 - example_grouped_gemm_multiple_d_dl_fp16 - example_grouped_gemm_xdl_splitk_fp16) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bfp16) -endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8) -endif() +add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32) + +add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16) + +add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16) + +add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16) + +add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16) + +add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16) + +add_example_executable(example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16) + +add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8) + +add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8) + if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) + add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) + add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) endif() diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp similarity index 100% rename from example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp rename to example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a193fc39ba637cbd41df4743e0157ca40c38407b --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp @@ -0,0 +1,353 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Row; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; + +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> d0_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, d0_tensors_device, + c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + d0_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " c_m_n: " << c_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + 1, + problem_size.stride_Bs[i], + 1, + {0}}); + + grouped_gemm_kernel_args_.push_back( + {a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + std::array{0}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ns.push_back(768); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ns[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c1feafce3787722f6400c07f95031423fa315d7 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -0,0 +1,327 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + int k_batch = 1; + bool time_kernel = false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_Cs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + } + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + c_tensors_device[i]->SetZero(); + + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + 1, + problem_size.stride_Bs[i], + 1, + {}}); + + grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op); + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(256); + problem_size.Ks.push_back(128); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9fd63cba77ac909cd6576e2ddd19f6c325473812 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp @@ -0,0 +1,328 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F8; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + int k_batch = 1; + bool time_kernel = false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_Cs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + } + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + c_tensors_device[i]->SetZero(); + + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + 1, + problem_size.stride_Bs[i], + 1, + {}}); + + grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op); + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(256); + problem_size.Ks.push_back(128); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp index 743ab96be6291e617cb012957865a6897bd61a91..9f8f6cb1e49d7a84f8a1bb596e568131874cf63f 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp @@ -66,13 +66,11 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; - problem_size.Ms = { - 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; - for(int i = 0; i < problem_size.group_count; i++) { - problem_size.Ns.push_back(768); - problem_size.Ks.push_back(4608); + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(128 + 128 * i); + problem_size.Ks.push_back(128 + 64 * i); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt index 00786d34a3af79a6d896e58b4fa12d9bbf8a7aa3..5955e1d6cb4589d5b71b7bd6d66f26a894423584 100644 --- a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt +++ b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt @@ -1,47 +1,48 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_gemm_reduce_xdl) - add_custom_target(example_gemm_reduce_xdl_max) - add_custom_target(example_gemm_reduce_xdl_mean_meansquare) - add_custom_target(example_gemm_add_add_mean_meansquare_xdl) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) - add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) - add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) - add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) - add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) - add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) - add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) - endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) - add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) - add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) - add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) - add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) - endif() - - add_dependencies(example_gemm_reduce_xdl - example_gemm_reduce_xdl_mean_meansquare - example_gemm_reduce_xdl_max - example_gemm_add_add_mean_meansquare_xdl) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) - endif() - set(target 1) - endif() + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_gemm_reduce_xdl) + add_custom_target(example_gemm_reduce_xdl_max) + add_custom_target(example_gemm_reduce_xdl_mean_meansquare) + add_custom_target(example_gemm_add_add_mean_meansquare_xdl) + + add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) + + add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) + add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) + + add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) + add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) + + add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) + + add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) + add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) + + add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) + + add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) + add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) + + add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) + + add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) + add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) + + add_example_dependencies(example_gemm_reduce_xdl + example_gemm_reduce_xdl_mean_meansquare + example_gemm_reduce_xdl_max + example_gemm_add_add_mean_meansquare_xdl) + + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) + endif() + set(target 1) + endif() endforeach() diff --git a/example/17_convnd_bwd_data/CMakeLists.txt b/example/17_convnd_bwd_data/CMakeLists.txt index e187bd4337d53e9c1e7b2ecedf023169386825e3..7c6d10d8a0d0375dd5bb8ba498b81a08d91bd7cd 100644 --- a/example/17_convnd_bwd_data/CMakeLists.txt +++ b/example/17_convnd_bwd_data/CMakeLists.txt @@ -1,15 +1,16 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) - target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) + if(result EQUAL 0) + target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) + endif() set(target 1) endif() endforeach() - if(DL_KERNELS) - add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) - target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) - endif() + +add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) +if(result EQUAL 0) + target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) endif() diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt index a1bb398af0640001f7fba2192e47ffc03e5ce3a0..94ed129dc03664043b1a0295f9f1dcfb38a77002 100644 --- a/example/18_batched_gemm_reduce/CMakeLists.txt +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -1,4 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() -endif() diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index e363dc5c12dd84a36a67e969b06d29a179679f94..62295c57eb3040cbf16547ec277f49a5e44e9a44 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -272,15 +272,14 @@ int main(int argc, char* argv[]) { for(int m = 0; m < M; ++m) { - auto reduce0_acc = reduce0_op.GetIdentityValue(); - auto reduce1_acc = reduce1_op.GetIdentityValue(); - + auto reduce0_acc = reduce0_op.GetIdentityValue(); + auto reduce1_acc = reduce1_op.GetIdentityValue(); + ReduceAccDataType d0_val = 0; + ReduceAccDataType d1_val = 0; for(int n = 0; n < N; ++n) { auto c_val = ck::type_convert(c_g_m_n_host_result(batch, m, n)); - ReduceAccDataType d0_val; - ReduceAccDataType d1_val; UnaryIdenticElementOp{}(d0_val, c_val); UnarySquareElementOp{}(d1_val, c_val); diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index d649567ed2c46ce05073d731a435fbf65c034b73..c28fca6fa296273fb2cbd76ab9b9c58a2832e050 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -1,24 +1,29 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_weight) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) - add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) - add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) - endif() - set(target 1) - endif() + if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_weight) + add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) + + add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) + add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) + + add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) + add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) + set(target 1) + endif() + + if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_weight) + add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16) + set(target 1) + endif() endforeach() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - if(DL_KERNELS) - add_custom_target(example_grouped_conv_bwd_weight_dl) - add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) - add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) - endif() -endif() \ No newline at end of file +add_custom_target(example_grouped_conv_bwd_weight_dl) + +add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) diff --git a/example/20_grouped_conv_bwd_weight/common.hpp b/example/20_grouped_conv_bwd_weight/common.hpp index 15727495f0f1f8dfedcb24a0ffa02d3d7aea67e0..1d2206ff72a7f64b23c79bb4874ace852a1c16c3 100644 --- a/example/20_grouped_conv_bwd_weight/common.hpp +++ b/example/20_grouped_conv_bwd_weight/common.hpp @@ -23,6 +23,12 @@ using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif template using S = ck::Sequence; @@ -40,25 +46,21 @@ struct CommonLayoutSetting using OutputLayout = OutputLay; }; -template -struct CommonLayoutSettingSelector; - namespace ctl = ck::tensor_layout::convolution; - -template <> -struct CommonLayoutSettingSelector<1> final : CommonLayoutSetting -{ -}; - -template <> -struct CommonLayoutSettingSelector<2> final - : CommonLayoutSetting -{ -}; - -template <> -struct CommonLayoutSettingSelector<3> final - : CommonLayoutSetting +template +struct CommonLayoutSettingSelector + : CommonLayoutSetting>, + ck::tuple_element_t>, + ck::tuple_element_t>> { }; @@ -78,10 +80,10 @@ struct ExecutionConfig final bool time_kernel = false; }; -#define DefaultConvParam \ - ck::utils::conv::ConvParam \ - { \ - 2, 4, 1, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, { 1, 1 } \ +#define DefaultConvParam \ + ck::utils::conv::ConvParam \ + { \ + 3, 4, 1, 128, 256, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, { 1, 1, 1 } \ } inline void print_help_msg() diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp index 375c309e1c93d30a6d7167d524613c9386510551..44ba3d8e02498ce72c6f0f4f8e9bdcaaf9ef1760 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp @@ -3,7 +3,7 @@ #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp" using InDataType = F16; using WeiDataType = F16; @@ -15,45 +15,84 @@ using WeiElementOp = PassThrough; using OutElementOp = PassThrough; template -using DeviceConvBwdWeightInstance = - ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl< - NDimSpatial, // NDimSpatial - InDataType, // InDataType - WeiDataType, // WeiDataType - OutDataType, // OutDataType - AccDataType, // AccDataType - InElementOp, // InElementwiseOperation - WeiElementOp, // WeiElementwiseOperation - OutElementOp, // OutElementwiseOperation - ConvBwdWeightDefault, // ConvBackwardWeightSpecialization - 256, // BlockSize - 128, // MPerBlock - 128, // NPerBlock - 16, // K0PerBlock - 2, // K1 - 4, // M1PerThread - 4, // N1PerThread - 1, // KPerThread - S<8, 2>, // M1N1ThreadClusterM1Xs - S<8, 2>, // M1N1ThreadClusterN1Xs - S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 - S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 - S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder - S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder - S<1, 1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 - S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder - S<1, 1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 - S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 - S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 - S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder - S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder - S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 - S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder - S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 - S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder - 5, // CThreadTransferSrcDstVectorDim - 4>; // CThreadTransferDstScalarPerVector +using DeviceConvBwdWeightInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl< + NDimSpatial, // NDimSpatial + ck::tuple_element_t>, // InLayout + ck::tuple_element_t>, // WeiLayout + ck::tuple_element_t>, // OutLayout + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 2, // K1 + 4, // M1PerThread + 4, // N1PerThread + 1, // KPerThread + S<8, 2>, // M1N1ThreadClusterM1Xs + S<8, 2>, // M1N1ThreadClusterN1Xs + S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 + S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 + S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder + S<1, 1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 + S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 + S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 + S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 + S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder + S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder + S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + 4>; // CThreadTransferDstScalarPerVector + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; #include "run_grouped_conv_bwd_weight_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c2950b17e802b9418c1614eb8c8c34168929c3b --- /dev/null +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" + +using InDataType = F16; +using WeiDataType = F16; +using OutDataType = F16; +using AccDataType = F32; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +template +using DeviceConvBwdWeightInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< + NDimSpatial, + ck::tensor_layout::convolution::GNDHWC, + ck::tensor_layout::convolution::GKZYXC, + ck::tensor_layout::convolution::GNDHWK, + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 16, // MPerWMMA + 16, // NPerWMMA + 4, // MRepeat + 2, // NRepeat + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // ABlockTransferSrcAccessOrder + 1, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 4, + 2, + S<1, 32, 1, 8>, + 1>; + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp index 31e277b5c709a44a3495348c2e9419dafeb7611c..80b97249304f1cb8403246785335d4c5c98784b1 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp @@ -67,6 +67,34 @@ using DeviceConvBwdWeightInstance = S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + #include "run_grouped_conv_bwd_weight_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp index 69c831cc54eef84fc0b7ea46f9961573bf3efff4..5c1c97d615df22a41a8331c800e499a1b433b047 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp @@ -66,6 +66,34 @@ using DeviceConvBwdWeightInstance = S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + #include "run_grouped_conv_bwd_weight_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d0d58eb2a3f227dbfbc9ad077f29d4f4102e677 --- /dev/null +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" + +using InDataType = F16; +using WeiDataType = F16; +using OutDataType = F16; +using AccDataType = F32; +using ComputeTypeA = BF8; +using ComputeTypeB = F8; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +template +using DeviceConvBwdWeightInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 1, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 1, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 2, // CBlockTransferScalarPerVector_NWaveNPerXdl + ComputeTypeA, // ComputeTypeA + ComputeTypeB>; // ComputeTypeB + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index 29ce0324abecbbfbc53864871fa6cf558339bbc2..73d15f3ea8c71a138a26bf9cf9aa58cd8fcf1e1c 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -1,33 +1,12 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -template -using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; - template bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_param) { - ck::index_t split_k; - // Set split_k = 2 for xdl op, split_k = 1 for dl - // Dl op doesn't support split_k > 1 - // TODO: Add Dl op split_k > 1 support - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102")) - { - split_k = 2; - } - else - { - split_k = 1; - } + // Dl and WMMA ops don't support split_k > 1 + constexpr ck::index_t split_k = 1; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< @@ -58,8 +37,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - out.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + in.GenerateTensorValue(GeneratorTensor_3{0.0, 0.2}); + out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}); } DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); @@ -125,18 +104,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, return true; } - float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - - std::size_t flop = conv_param.GetFlops(); - std::size_t num_btype = conv_param.GetByte(); - - float tflops = static_cast(flop) / 1.E9 / avg_time; - - float gb_per_sec = num_btype / 1.E6 / avg_time; - - std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl - << "DeviceOp: " << conv.GetTypeString() << std::endl; + invoker.Run(argument, StreamConfig{nullptr, false}); if(config.do_verification) { @@ -160,25 +128,18 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData); } - return true; -} + float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); -bool run_grouped_conv_bwd_weight_example(int argc, char* argv[]) -{ - ExecutionConfig config; - ck::utils::conv::ConvParam conv_param = DefaultConvParam; + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); - if(!parse_cmd_args(argc, argv, config, conv_param)) - { - return false; - } + float tflops = static_cast(flop) / 1.E9 / avg_time; - switch(conv_param.num_dim_spatial_) - { - case 1: return run_grouped_conv_bwd_weight<1>(config, conv_param); - case 2: return run_grouped_conv_bwd_weight<2>(config, conv_param); - case 3: return run_grouped_conv_bwd_weight<3>(config, conv_param); - } + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl + << "DeviceOp: " << conv.GetTypeString() << std::endl; - return false; + return true; } diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt index 6a6735efd626bc1931fa9f7e6651dceae713da11..e231bc619b781c19f447b7fa7442f2beb58331cb 100644 --- a/example/21_gemm_layernorm/CMakeLists.txt +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -1,4 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -10,4 +9,4 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() -endif() + diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp index fc58ca19f8673aa0b2205df21350f7ddf0b92196..6a92e9a2f50c8a5f1b80228307ad946188ea3c75 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp @@ -114,12 +114,15 @@ void host_gemm_layernorm(Tensor& h_m_n, BetaDataType, HDataType, AccDataType, + AccDataType, HElementOp, 2, 1>; Tensor e_m_n(HostTensorDescriptor{M, N}); Tensor c_m_n(HostTensorDescriptor{M, N}); + Tensor save_mean({M}); + Tensor save_inv_std({M}); auto ref_gemm = ReferenceGemm{}; auto ref_gemm_invoker = ref_gemm.MakeInvoker(); @@ -145,7 +148,7 @@ void host_gemm_layernorm(Tensor& h_m_n, auto ref_layernorm_invoker = ref_layernorm.MakeInvoker(); auto ref_layernorm_argument = ref_layernorm.MakeArgument( - e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon); + e_m_n, gamma_n, beta_n, h_m_n, save_mean, save_inv_std, h_element_op, {M, N}, {1}, epsilon); ref_layernorm_invoker.Run(ref_layernorm_argument); } diff --git a/example/22_cgemm/CMakeLists.txt b/example/22_cgemm/CMakeLists.txt index 854f07fda6aad404fda546b40bef1c8b45c766d5..44585b11d04e762f58cc20c1816200857eef0fa6 100644 --- a/example/22_cgemm/CMakeLists.txt +++ b/example/22_cgemm/CMakeLists.txt @@ -1,22 +1,18 @@ add_custom_target(example_cgemm_xdl) -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) - add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16) -endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) - add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) +add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16) + +add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16) + add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp) -add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32) -endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) - add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8) -endif() +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32) + +add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int8) + if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp) - add_dependencies(example_cgemm_xdl example_cgemm_xdl_int4) + add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp) + add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int4) endif() diff --git a/example/24_batched_gemm/CMakeLists.txt b/example/24_batched_gemm/CMakeLists.txt index 48a3b58ff53a712492e82ffce5632c08b6937ba2..4cb45be7c90182928e6444c3d923c65b28add384 100644 --- a/example/24_batched_gemm/CMakeLists.txt +++ b/example/24_batched_gemm/CMakeLists.txt @@ -1,21 +1,18 @@ add_custom_target(example_batched_gemm_xdl) -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32) -endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bfp16) -endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8) -endif() + +add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32) + +add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16) + +add_example_executable(example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16) + +add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8) + if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) + add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) + add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) endif() diff --git a/example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp b/example/24_batched_gemm/batched_gemm_xdl_bf16.cpp similarity index 100% rename from example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp rename to example/24_batched_gemm/batched_gemm_xdl_bf16.cpp diff --git a/example/25_gemm_bias_e_permute/CMakeLists.txt b/example/25_gemm_bias_e_permute/CMakeLists.txt index eb274b233802401628abe0f030159da75b8f5950..cbc3c007bc22622554fd93f4d8a829c9ab666dc4 100644 --- a/example/25_gemm_bias_e_permute/CMakeLists.txt +++ b/example/25_gemm_bias_e_permute/CMakeLists.txt @@ -1,4 +1,2 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp) - add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp) -endif() +add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp) +add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp) diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt index 6cab88b13ef5761a16cedcd9d86d7f855fe5b99e..1a0489ce9ca9efecf6b0267258944ac8c17cc05a 100644 --- a/example/26_contraction/CMakeLists.txt +++ b/example/26_contraction/CMakeLists.txt @@ -1,8 +1,52 @@ -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) - add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) -endif() -if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) - add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) -endif() +add_custom_target(example_contraction) +add_custom_target(example_contraction_scale) +add_custom_target(example_contraction_bilinear) + +# FP32 +add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) +add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32) + +add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) +add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32) + +add_example_executable(example_contraction_bilinear_xdl_fp32_compute_bf16 contraction_bilinear_xdl_fp32_compute_bf16.cpp) +add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16) + +add_example_executable(example_contraction_scale_xdl_fp32_compute_bf16 contraction_scale_xdl_fp32_compute_bf16.cpp) +add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16) + +add_example_executable(example_contraction_bilinear_xdl_fp32_compute_fp16 contraction_bilinear_xdl_fp32_compute_fp16.cpp) +add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16) + +add_example_executable(example_contraction_scale_xdl_fp32_compute_fp16 contraction_scale_xdl_fp32_compute_fp16.cpp) +add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16) + +# FP64 +add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) +add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64) + +add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) +add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64) + +add_example_executable(example_contraction_bilinear_xdl_fp64_compute_fp32 contraction_bilinear_xdl_fp64_compute_fp32.cpp) +add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32) + +add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contraction_scale_xdl_fp64_compute_fp32.cpp) +add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32) + +# FP16 +add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp) +add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32) + +add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp) +add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32) + +# BF16 +add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 contraction_bilinear_xdl_bf16_compute_fp32.cpp) +add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32) + +add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp) +add_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32) + +add_dependencies(example_contraction example_contraction_scale) +add_dependencies(example_contraction example_contraction_bilinear) diff --git a/example/26_contraction/common_instances.hpp b/example/26_contraction/common_instances.hpp new file mode 100644 index 0000000000000000000000000000000000000000..480ca5a0af67bcf8db871fcd79e795497901c5ab --- /dev/null +++ b/example/26_contraction/common_instances.hpp @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using F64 = double; + +template +using S = ck::Sequence; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Generic instances for fp32, fp16 and bf16 data types. +template +// clang-format off +using DeviceOpInstanceKK_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceKN_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMK_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMN_Generic = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; +// clang-format on + +// Fp64 instances. +template +// clang-format off +using DeviceOpInstanceKK_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceKN_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMK_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on + +template +// clang-format off +using DeviceOpInstanceMN_FP64 = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>; +// clang-format on diff --git a/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..77ec5669a76e093fd351f82b64d77ee40a147ad4 --- /dev/null +++ b/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DDataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..25c4d9eac467f694a61b3f697dd6fd99848353b4 --- /dev/null +++ b/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp index 78522160c85a1024ec510197e5ab1069e6f077cf..d440e7394e5da2794f3a959ea769587cb60afc7f 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp @@ -1,29 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include - #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/numeric.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" - -template -using S = ck::Sequence; - -using F32 = float; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +#include "common_instances.hpp" using ADataType = F32; using BDataType = F32; @@ -32,6 +13,7 @@ using CShuffleDataType = F32; using DDataType = F32; using DsDataType = ck::Tuple; using EDataType = F32; +using ComputeDataType = F32; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -41,253 +23,64 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// clang-format off -using DeviceOpInstanceKKNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; - -using DeviceOpInstanceKNNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; - -using DeviceOpInstanceMKNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; - -using DeviceOpInstanceMNNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; -// clang-format on +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; using DeviceOpInstance = DeviceOpInstanceKKNN; -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // A[M0, M1, K0, K1] - std::vector a_ms_ks_lengths{30, 128, 32, 64}; - std::vector a_ms_ks_strides{524288, 4096, 128, 1}; - // B[N0, N1, K0, K1] - std::vector b_ns_ks_lengths{32, 64, 32, 64}; - std::vector b_ns_ks_strides{524288, 4096, 128, 1}; - // D[M0, M1, N0, N1] - std::vector d_ms_ns_lengths{30, 128, 32, 64}; - std::vector d_ms_ns_strides{524288, 4096, 128, 1}; - // E[M0, M1, N0, N1] - std::vector e_ms_ns_lengths{30, 128, 32, 64}; - std::vector e_ms_ns_strides{524288, 4096, 128, 1}; - - float alpha = 1.f; - float beta = 1.f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 28) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - const ck::index_t M0 = std::stoi(argv[4]); - const ck::index_t M1 = std::stoi(argv[5]); - - const ck::index_t N0 = std::stoi(argv[6]); - const ck::index_t N1 = std::stoi(argv[7]); - - const ck::index_t K0 = std::stoi(argv[8]); - const ck::index_t K1 = std::stoi(argv[9]); - - a_ms_ks_lengths = {M0, M1, K0, K1}; - a_ms_ks_strides = { - std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; - - b_ns_ks_lengths = {N0, N1, K0, K1}; - b_ns_ks_strides = { - std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; - - d_ms_ns_lengths = {M0, M1, N0, N1}; - d_ms_ns_strides = { - std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; - - e_ms_ns_lengths = {M0, M1, N0, N1}; - e_ms_ns_strides = { - std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; - - alpha = std::stof(argv[26]); - beta = std::stof(argv[27]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n"); - printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); - printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); - printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); - printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); - printf("arg26 to 27: alpha, beta\n"); - exit(0); - } - - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); - - std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; - std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; - std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; - std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_ms_ks.mData.data()); - b_device_buf.ToDevice(b_ns_ks.mData.data()); - d_device_buf.ToDevice(d_ms_ns.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{alpha, beta}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - a_ms_ks_lengths, - a_ms_ks_strides, - b_ns_ks_lengths, - b_ns_ks_strides, - std::array, 1>{d_ms_ns_lengths}, - std::array, 1>{d_ms_ns_strides}, - e_ms_ns_lengths, - e_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t M = - ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - - using ReferenceOpInstance = - ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; - - auto ref_op = ReferenceOpInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - - auto ref_argument = - ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); - - ref_invoker.Run(ref_argument); - - for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) - { - for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) - { - cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), - c_ms_ns_host_result(m0, m1, n0, n1), - d_ms_ns(m0, m1, n0, n1)); - } - } - } - } - - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; - } +#include "run_contraction_bilinear_example.inc" - return 0; -} +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b748949bd53ce0954e5e62440b74c3d1c5aa235 --- /dev/null +++ b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; +using ComputeDataType = BF16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..334cd615d17e746e4112e54aad8a0107c0f67cea --- /dev/null +++ b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; +using ComputeDataType = F16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp index 6cceed5bc11ed8eb79568a0397bbf7e1b93fd33d..86ea7e0212d62a5f0e6773bbba0b3bffaad005c7 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp @@ -1,29 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include - #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/numeric.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" - -template -using S = ck::Sequence; - -using F64 = double; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +#include "common_instances.hpp" using ADataType = F64; using BDataType = F64; @@ -32,6 +13,7 @@ using CShuffleDataType = F64; using DDataType = F64; using DsDataType = ck::Tuple; using EDataType = F64; +using ComputeDataType = F64; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -41,253 +23,64 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// clang-format off -using DeviceOpInstanceKKNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; - -using DeviceOpInstanceKNNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; - -using DeviceOpInstanceMKNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; - -using DeviceOpInstanceMNNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; -// clang-format on +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64; using DeviceOpInstance = DeviceOpInstanceKKNN; -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // A[M0, M1, K0, K1] - std::vector a_ms_ks_lengths{30, 128, 32, 64}; - std::vector a_ms_ks_strides{524288, 4096, 128, 1}; - // B[N0, N1, K0, K1] - std::vector b_ns_ks_lengths{32, 64, 32, 64}; - std::vector b_ns_ks_strides{524288, 4096, 128, 1}; - // D[M0, M1, N0, N1] - std::vector d_ms_ns_lengths{30, 128, 32, 64}; - std::vector d_ms_ns_strides{524288, 4096, 128, 1}; - // E[M0, M1, N0, N1] - std::vector e_ms_ns_lengths{30, 128, 32, 64}; - std::vector e_ms_ns_strides{524288, 4096, 128, 1}; - - float alpha = 1.f; - float beta = 1.f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 28) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - const ck::index_t M0 = std::stoi(argv[4]); - const ck::index_t M1 = std::stoi(argv[5]); - - const ck::index_t N0 = std::stoi(argv[6]); - const ck::index_t N1 = std::stoi(argv[7]); - - const ck::index_t K0 = std::stoi(argv[8]); - const ck::index_t K1 = std::stoi(argv[9]); - - a_ms_ks_lengths = {M0, M1, K0, K1}; - a_ms_ks_strides = { - std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; - - b_ns_ks_lengths = {N0, N1, K0, K1}; - b_ns_ks_strides = { - std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; - - d_ms_ns_lengths = {M0, M1, N0, N1}; - d_ms_ns_strides = { - std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; - - e_ms_ns_lengths = {M0, M1, N0, N1}; - e_ms_ns_strides = { - std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; - - alpha = std::stof(argv[26]); - beta = std::stof(argv[27]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n"); - printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); - printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); - printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); - printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); - printf("arg26 to 27: alpha, beta\n"); - exit(0); - } - - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); - - std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; - std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; - std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; - std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_ms_ks.mData.data()); - b_device_buf.ToDevice(b_ns_ks.mData.data()); - d_device_buf.ToDevice(d_ms_ns.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{alpha, beta}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - a_ms_ks_lengths, - a_ms_ks_strides, - b_ns_ks_lengths, - b_ns_ks_strides, - std::array, 1>{d_ms_ns_lengths}, - std::array, 1>{d_ms_ns_strides}, - e_ms_ns_lengths, - e_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t M = - ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - - using ReferenceOpInstance = - ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; - - auto ref_op = ReferenceOpInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - - auto ref_argument = - ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); - - ref_invoker.Run(ref_argument); - - for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) - { - for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) - { - cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), - c_ms_ns_host_result(m0, m1, n0, n1), - d_ms_ns(m0, m1, n0, n1)); - } - } - } - } - - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; - } +#include "run_contraction_bilinear_example.inc" - return 0; -} +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b0a2a9a10fe59abeac58d2ec3e43af012b51d82b --- /dev/null +++ b/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F64; +using BDataType = F64; +using AccDataType = F32; +using CShuffleDataType = F64; +using DDataType = F64; +using DsDataType = ck::Tuple; +using EDataType = F64; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64; + +using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64; + +using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64; + +using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +#include "run_contraction_bilinear_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..32652fc6f3c8856505c3afc3c764fb0d08f23f48 --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = BF16; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKN; + +#include "run_contraction_scale_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a4797e24008984f281718ba4e8e6cdea28555714 --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DsDataType = ck::Tuple<>; +using EDataType = F16; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKN; + +#include "run_contraction_scale_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp32.cpp index 1574f5d18fb7f5f193cd4adb1c1a88616d7b1437..fe984b29d625b69ecf5555a3045f5b925752a1cd 100644 --- a/example/26_contraction/contraction_scale_xdl_fp32.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp32.cpp @@ -1,29 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include - #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/numeric.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" - -template -using S = ck::Sequence; - -using F32 = float; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +#include "common_instances.hpp" using ADataType = F32; using BDataType = F32; @@ -31,6 +12,7 @@ using AccDataType = F32; using CShuffleDataType = F32; using DsDataType = ck::Tuple<>; using EDataType = F32; +using ComputeDataType = F32; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -40,237 +22,64 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// clang-format off -using DeviceOpInstanceKKN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; - -using DeviceOpInstanceKNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; - -using DeviceOpInstanceMKN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; - -using DeviceOpInstanceMNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; -// clang-format on +using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; using DeviceOpInstance = DeviceOpInstanceKKN; -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // A[M0, M1, K0, K1] - std::vector a_ms_ks_lengths{30, 128, 32, 64}; - std::vector a_ms_ks_strides{524288, 4096, 128, 1}; - // B[N0, N1, K0, K1] - std::vector b_ns_ks_lengths{32, 64, 32, 64}; - std::vector b_ns_ks_strides{524288, 4096, 128, 1}; - // E[M0, M1, N0, N1] - std::vector e_ms_ns_lengths{30, 128, 32, 64}; - std::vector e_ms_ns_strides{524288, 4096, 128, 1}; - - float scale = 1.f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 23) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - const ck::index_t M0 = std::stoi(argv[4]); - const ck::index_t M1 = std::stoi(argv[5]); - - const ck::index_t N0 = std::stoi(argv[6]); - const ck::index_t N1 = std::stoi(argv[7]); - - const ck::index_t K0 = std::stoi(argv[8]); - const ck::index_t K1 = std::stoi(argv[9]); - - a_ms_ks_lengths = {M0, M1, K0, K1}; - a_ms_ks_strides = { - std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; - - b_ns_ks_lengths = {N0, N1, K0, K1}; - b_ns_ks_strides = { - std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; - - e_ms_ns_lengths = {M0, M1, N0, N1}; - e_ms_ns_strides = { - std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; - - scale = std::stof(argv[22]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); - printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); - printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); - printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); - printf("arg22: scale\n"); - exit(0); - } - - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); - - std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; - std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; - std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_ms_ks.mData.data()); - b_device_buf.ToDevice(b_ns_ks.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{scale}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{}, - e_device_buf.GetDeviceBuffer(), - a_ms_ks_lengths, - a_ms_ks_strides, - b_ns_ks_lengths, - b_ns_ks_strides, - std::array, 0>{}, - std::array, 0>{}, - e_ms_ns_lengths, - e_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t M = - ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - - using ReferenceOpInstance = - ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; - - auto ref_op = ReferenceOpInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - - Tensor empty_tensor(std::vector{}, std::vector{}); - auto ref_argument = - ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); - - ref_invoker.Run(ref_argument); - - for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) - { - for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) - { - cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), - c_ms_ns_host_result(m0, m1, n0, n1)); - } - } - } - } - - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; - } +#include "run_contraction_scale_example.inc" - return 0; -} +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp b/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..83c8fec179d1cbdf4121c9c16a28a024458c0695 --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; +using ComputeDataType = BF16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKN; + +#include "run_contraction_scale_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp b/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..27e42e82037610de7f9d5bdfbc1f936e56d8733f --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; +using ComputeDataType = F16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; + +using DeviceOpInstance = DeviceOpInstanceKKN; + +#include "run_contraction_scale_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp64.cpp b/example/26_contraction/contraction_scale_xdl_fp64.cpp index 3dacc708877d838dd5fb0113b5d560885b0551ae..37374780b6864cf1b23391749f41ffc55236622a 100644 --- a/example/26_contraction/contraction_scale_xdl_fp64.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp64.cpp @@ -1,29 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include -#include -#include -#include - #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/numeric.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" - -template -using S = ck::Sequence; - -using F64 = double; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +#include "common_instances.hpp" using ADataType = F64; using BDataType = F64; @@ -31,6 +12,7 @@ using AccDataType = F64; using CShuffleDataType = F64; using DsDataType = ck::Tuple<>; using EDataType = F64; +using ComputeDataType = F64; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -40,237 +22,64 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// clang-format off -using DeviceOpInstanceKKN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; - -using DeviceOpInstanceKNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; - -using DeviceOpInstanceMKN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; - -using DeviceOpInstanceMNN = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; -// clang-format on +using DeviceOpInstanceKKN = DeviceOpInstanceKK_FP64; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_FP64; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_FP64; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_FP64; using DeviceOpInstance = DeviceOpInstanceKKN; -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // A[M0, M1, K0, K1] - std::vector a_ms_ks_lengths{30, 128, 32, 64}; - std::vector a_ms_ks_strides{524288, 4096, 128, 1}; - // B[N0, N1, K0, K1] - std::vector b_ns_ks_lengths{32, 64, 32, 64}; - std::vector b_ns_ks_strides{524288, 4096, 128, 1}; - // E[M0, M1, N0, N1] - std::vector e_ms_ns_lengths{30, 128, 32, 64}; - std::vector e_ms_ns_strides{524288, 4096, 128, 1}; - - float scale = 1.f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 23) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - const ck::index_t M0 = std::stoi(argv[4]); - const ck::index_t M1 = std::stoi(argv[5]); - - const ck::index_t N0 = std::stoi(argv[6]); - const ck::index_t N1 = std::stoi(argv[7]); - - const ck::index_t K0 = std::stoi(argv[8]); - const ck::index_t K1 = std::stoi(argv[9]); - - a_ms_ks_lengths = {M0, M1, K0, K1}; - a_ms_ks_strides = { - std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; - - b_ns_ks_lengths = {N0, N1, K0, K1}; - b_ns_ks_strides = { - std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; - - e_ms_ns_lengths = {M0, M1, N0, N1}; - e_ms_ns_strides = { - std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; - - scale = std::stof(argv[22]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); - printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); - printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); - printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); - printf("arg22: scale\n"); - exit(0); - } - - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); - - std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; - std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; - std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_ms_ks.mData.data()); - b_device_buf.ToDevice(b_ns_ks.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{scale}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{}, - e_device_buf.GetDeviceBuffer(), - a_ms_ks_lengths, - a_ms_ks_strides, - b_ns_ks_lengths, - b_ns_ks_strides, - std::array, 0>{}, - std::array, 0>{}, - e_ms_ns_lengths, - e_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t M = - ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - - using ReferenceOpInstance = - ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; - - auto ref_op = ReferenceOpInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - - Tensor empty_tensor(std::vector{}, std::vector{}); - auto ref_argument = - ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); - - ref_invoker.Run(ref_argument); - - for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) - { - for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) - { - cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), - c_ms_ns_host_result(m0, m1, n0, n1)); - } - } - } - } - - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; - } +#include "run_contraction_scale_example.inc" - return 0; -} +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5dbeb231542539d989de97c3c64b16426876664d --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "common_instances.hpp" + +using ADataType = F64; +using BDataType = F64; +using AccDataType = F32; +using CShuffleDataType = F64; +using DsDataType = ck::Tuple<>; +using EDataType = F64; +using ComputeDataType = F32; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +using DeviceOpInstanceKKN = DeviceOpInstanceKK_FP64; + +using DeviceOpInstanceKNN = DeviceOpInstanceKN_FP64; + +using DeviceOpInstanceMKN = DeviceOpInstanceMK_FP64; + +using DeviceOpInstanceMNN = DeviceOpInstanceMN_FP64; + +using DeviceOpInstance = DeviceOpInstanceKKN; + +#include "run_contraction_scale_example.inc" + +int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/run_contraction_bilinear_example.inc b/example/26_contraction/run_contraction_bilinear_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..7b9a01e0ae80b5bd5035fe213a3fe065e8a7a408 --- /dev/null +++ b/example/26_contraction/run_contraction_bilinear_example.inc @@ -0,0 +1,234 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +int run_contraction_bilinear_example(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 28) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; + + alpha = std::stof(argv[26]); + beta = std::stof(argv[27]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg26 to 27: alpha, beta\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + d_device_buf.ToDevice(d_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1), + d_ms_ns(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/26_contraction/run_contraction_scale_example.inc b/example/26_contraction/run_contraction_scale_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..65ca182da2cf7c5f97688f3a7c6d70b19cec016f --- /dev/null +++ b/example/26_contraction/run_contraction_scale_example.inc @@ -0,0 +1,217 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +int run_contraction_scale_example(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float scale = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 23) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + scale = std::stof(argv[22]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg22: scale\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{scale}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/27_layernorm/CMakeLists.txt b/example/27_layernorm/CMakeLists.txt deleted file mode 100644 index 9cb2cd0766ea8c5aaa9154e87457763cc38372ab..0000000000000000000000000000000000000000 --- a/example/27_layernorm/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp) - add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp) -endif() diff --git a/example/27_layernorm/layernorm_fp16.cpp b/example/27_layernorm/layernorm_fp16.cpp deleted file mode 100644 index bb8b954f0acaf2e9d748104a05561f9cb02fc9a0..0000000000000000000000000000000000000000 --- a/example/27_layernorm/layernorm_fp16.cpp +++ /dev/null @@ -1,39 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -constexpr int Rank = 2; -constexpr int NumReduceDim = 1; - -using DeviceInstance = - ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector -#include "run_layernorm_example.inc" - -int main() { return run_groupnorm_example(); } diff --git a/example/27_layernorm/run_layernorm_example.inc b/example/27_layernorm/run_layernorm_example.inc deleted file mode 100644 index 95200b540aa9f9704c8fad785b352bdf779c4094..0000000000000000000000000000000000000000 --- a/example/27_layernorm/run_layernorm_example.inc +++ /dev/null @@ -1,96 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -template -int run_groupnorm_example() -{ - bool time_kernel = false; - - ck::index_t M = 1024; - ck::index_t N = 1024; - ck::index_t Stride = N; - - auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { - return HostTensorDescriptor({len}, {stride}); - }; - - auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) { - using namespace ck::literals; - - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - }; - - Tensor x(f_host_tensor_descriptor2d(M, N, Stride)); - Tensor gamma(f_host_tensor_descriptor1d(N, 1)); - Tensor beta(f_host_tensor_descriptor1d(N, 1)); - Tensor y(f_host_tensor_descriptor2d(M, N, Stride)); - - x.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); - DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); - DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); - DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); - - x_dev.ToDevice(x.mData.data()); - gamma_dev.ToDevice(gamma.mData.data()); - beta_dev.ToDevice(beta.mData.data()); - - auto device_instance = DeviceInstance{}; - auto argument_ptr = device_instance.MakeArgumentPointer( - {M, N}, - std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, - {0, 1}, - {0, 1}, - std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, - {1}, - 1e-4, - x_dev.GetDeviceBuffer(), - gamma_dev.GetDeviceBuffer(), - beta_dev.GetDeviceBuffer(), - y_dev.GetDeviceBuffer(), - nullptr, - nullptr, - PassThrough{}); - - if(!device_instance.IsSupportedArgument(argument_ptr.get())) - { - std::cout << "The runtime parameters are not supported" << std::endl; - return 1; - }; - - size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get()); - DeviceMem workspace_dev(workspace_sz); - device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); - - auto invoker_ptr = device_instance.MakeInvokerPointer(); - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - - bool pass = true; - { - Tensor host_y(f_host_tensor_descriptor2d(M, N, Stride)); - using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm; - - ReferenceInstance ref; - auto ref_argument = - ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {M, N}, {1}, 1e-4); - auto ref_invoker = ref.MakeInvoker(); - ref_invoker.Run(ref_argument); - - y_dev.FromDevice(y.mData.data()); - pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); - } - - return (pass ? 0 : 1); -} diff --git a/example/27_layernorm2d_fwd/CMakeLists.txt b/example/27_layernorm2d_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..639bd9c400e428599d606eb3e43fd3078e8382e2 --- /dev/null +++ b/example/27_layernorm2d_fwd/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_layernorm2d_fwd_fp16 layernorm2d_fwd_fp16.cpp) +add_example_executable(example_layernorm2d_fwd_splitk_fp16 layernorm2d_fwd_splitk_fp16.cpp) diff --git a/example/27_layernorm/common.hpp b/example/27_layernorm2d_fwd/common.hpp similarity index 94% rename from example/27_layernorm/common.hpp rename to example/27_layernorm2d_fwd/common.hpp index 62a71713df84351fea902cdcf3275786b60f5a23..6c7b99f89f9930cdcb84bda674f96ab3a1237241 100644 --- a/example/27_layernorm/common.hpp +++ b/example/27_layernorm2d_fwd/common.hpp @@ -10,8 +10,8 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" diff --git a/example/27_layernorm2d_fwd/layernorm2d_fwd_fp16.cpp b/example/27_layernorm2d_fwd/layernorm2d_fwd_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..20db24f56d38b22d20100a03f51f83fd21815125 --- /dev/null +++ b/example/27_layernorm2d_fwd/layernorm2d_fwd_fp16.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationFwdImpl; // SaveMeanInvStdScalarPerVector +#include "run_layernorm_example.inc" + +int main() { return run_layernorm2d_fwd_example(); } diff --git a/example/27_layernorm2d_fwd/layernorm2d_fwd_splitk_fp16.cpp b/example/27_layernorm2d_fwd/layernorm2d_fwd_splitk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a57082c79f2e5c56527c0396952ebba2e9c7bd9 --- /dev/null +++ b/example/27_layernorm2d_fwd/layernorm2d_fwd_splitk_fp16.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +using DeviceInstance = ck::tensor_operation::device::DeviceNormalizationFwdSplitKImpl< + XDataType, + GammaDataType, + BetaDataType, + ComputeDataType, + YDataType, + SaveMeanInvStdDataType, + PassThrough, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // ClusterM + 32, // ClusterK + 1, // SliceM + 8, // SliceK + 1, // XYVectorDim (0=M, 1=K) + 8, // XScalarPerVector + 1, // GammaVecDim (0=M, 1=K) + 8, // GammaScalarPerVector + 1, // BetaVecDim (0=M, 1=K) + 8, // BetaScalarPerVector + 8, // YScalarPerVector + 1>; // SaveMeanInvStdScalarPerVector + +#include "run_layernorm_example.inc" + +int main() { return run_layernorm2d_fwd_example(); } diff --git a/example/27_layernorm2d_fwd/run_layernorm_example.inc b/example/27_layernorm2d_fwd/run_layernorm_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..23608a1eea01bffa20ebc1938ce4e7b4769e26a1 --- /dev/null +++ b/example/27_layernorm2d_fwd/run_layernorm_example.inc @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +int run_layernorm2d_fwd_example() +{ + bool time_kernel = false; + + ck::index_t M = 1024; + ck::index_t N = 1024; + + Tensor x({M, N}); + Tensor gamma({N}); + Tensor beta({N}); + Tensor y({M, N}); + Tensor save_mean({M}); + Tensor save_inv_std({M}); + + x.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); +#ifdef SAVE_MEAN_INV_STD + DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize()); + DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) * + save_inv_std.mDesc.GetElementSpaceSize()); +#endif + + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + auto device_instance = DeviceInstance{}; + auto argument_ptr = device_instance.MakeArgumentPointer( + {M, N}, + std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, + {0, 1}, + {0, 1}, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + {1}, + 1e-4, + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.GetDeviceBuffer(), + save_inv_std_dev.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + PassThrough{}); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported" << std::endl; + return 1; + }; + + size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor host_y({M, N}); + Tensor host_save_mean({M}); + Tensor host_save_inv_std({M}); + + using ReferenceInstance = + ck::tensor_operation::host::ReferenceLayernorm; + + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument(x, + gamma, + beta, + host_y, + host_save_mean, + host_save_inv_std, + PassThrough{}, + {M, N}, + {1}, + 1e-4); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + y_dev.FromDevice(y.mData.data()); + pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results (y)", 1e-3, 1e-3); +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.FromDevice(save_mean.mData.data()); + save_inv_std_dev.FromDevice(save_inv_std.mData.data()); + pass &= ck::utils::check_err( + save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3); + pass &= ck::utils::check_err( + save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3); +#endif + } + + return (pass ? 0 : 1); +} diff --git a/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt b/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt index 2fda1f62a9a94c69bde8f08e8c81167b2951c4d6..44ab16894ce054a4b25b2efd12ebac9152b3bd58 100644 --- a/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt +++ b/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt @@ -1,3 +1 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp) -endif() +add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp) diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index 09c3e6c6085c233aeac87374e6996cade4dac142..32a87dd200fc5f9048fc2eb1f56f343fe5586c64 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1,7 +1,5 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) - if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") - add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) - endif() +if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") + add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) endif() diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index e37413c09880de71b5bc10b15624fb98e75bf4ba..3a8c2ef52fac50d10bf2d1167ef1d4819281d312 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -3,44 +3,38 @@ list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_custom_target(example_grouped_conv_fwd_multiple_d) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) - add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) - endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) - endif() - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) - endif() # USE_BITINT_EXTENSION_INT4 - - set(target 1) - endif() + if(gpu IN_LIST gpu_list1 AND target EQUAL 0) + add_custom_target(example_grouped_conv_fwd_multiple_d) + + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) + + add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) + + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) + + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) + + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) + + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) + endif() # USE_BITINT_EXTENSION_INT4 + + set(target 1) + endif() endforeach() set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) - endif() - set(target 1) - endif() + if(gpu IN_LIST gpu_list2 AND target EQUAL 0) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) + set(target 1) + endif() endforeach() diff --git a/example/30_grouped_conv_fwd_multiple_d/README.md b/example/30_grouped_conv_fwd_multiple_d/README.md index 739a0425a8c29419c3e31e433815c5230bd1a5cb..7a0cb2d0e4ba47b79b8e524c5ac29eab08dad57b 100644 --- a/example/30_grouped_conv_fwd_multiple_d/README.md +++ b/example/30_grouped_conv_fwd_multiple_d/README.md @@ -4,7 +4,7 @@ arg1: verification (0=no, 1=yes) arg2: initialization (0=no init, 1=integer value, 2=decimal value) arg3: time kernel (0=no, 1=yes) Following arguments (depending on number of spatial dims): - Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d) + Number of spatial dimensions (1=Conv1D, 2=Conv2D, 3=Conv3D) G, N, K, C, , (ie Y, X for 2D) , (ie Hi, Wi for 2D) @@ -26,5 +26,5 @@ out: dim 5, lengths {1, 128, 256, 36, 36}, strides {256, 331776, 1, 9216, 256} launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} Warm up 1 time Start running 10 times... -Perf: 1.55981 ms, 94.0927 TFlops, 213.868 GB/s, DeviceGroupedConvFwdMultipleD_Xdl_CShuffle<256, 128, 256, 16, Default> +Perf: 1.55981 ms, 94.0927 TFlops, 213.868 GB/s, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<256, 128, 256, 16, Default> ``` diff --git a/example/30_grouped_conv_fwd_multiple_d/common.hpp b/example/30_grouped_conv_fwd_multiple_d/common.hpp index e60ebee6e4fe2e08db432816b75bbbd4c388946c..39463614f5b34dde88992fe43e32b1985c975a38 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common.hpp @@ -12,7 +12,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp index 5494563fdd568d1f51ec9ee9042d094096e8fc8a..6f91d51a5fd7195472805a84d1f123b208fe9f67 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include "common.hpp" @@ -29,3 +27,4 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; #include "run_grouped_conv_fwd_bias_relu_add_example.inc" int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } +#endif diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc index eb242203eaa65ba7b8fca45c5f1e3e1b9d96c409..e3370b880bb29e78a88e5ed4ca1fcdcd14afa2f8 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc @@ -34,7 +34,7 @@ using ResidualLayout = typename LayoutSettingSelector::ResidualLayo template using DeviceConvFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InputLayout, WeightLayout, diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc index 58ed69182e0622751dc3fa2823fcb7cfa5b279a8..da65bb1886fe11c8fdf888e772fde6892e7b3d29 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc @@ -3,7 +3,7 @@ template using DeviceConvFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InputLayout, WeightLayout, diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt b/example/31_batched_gemm_gemm/CMakeLists.txt index 2074520f8c06d527be4ff1ea6a87c388cfbd6391..93f16c945f22f4180d6a79477be3fa0aca130e5e 100644 --- a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -1,17 +1,11 @@ list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list2 gfx908 gfx90a) + set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) - endif() + add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) + add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) + add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) endif(USE_BITINT_EXTENSION_INT4) @@ -20,7 +14,5 @@ foreach(gpu IN LISTS GPU_TARGETS) endforeach() if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) - endif() + add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) endif() diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp index d166214c3376cd90afd11796a4cb85ea28421861..2caee6b8dc5dc0a21bcc7b6526b463325cbce3d1 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp @@ -9,9 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o Gemm1 */ -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include #include @@ -144,3 +142,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); #endif int main(int argc, char* argv[]) { return run_batched_gemm_gemm_example(argc, argv) ? 0 : 1; } +#endif diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index 0463bf6bd31483ada687f17ff175e0195404623c..2a24abf094b09d40b7feef6e9a46a782dc7610d8 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -1,24 +1,23 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) - add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) - add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) - add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) - add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) - add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) -endif() - add_custom_target(example_gemm_scale_softmax_gemm) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) - add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) - add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) - add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) -endif() + +add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) + +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) + +add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) + +add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) + +add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) + +add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) + +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) + diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index 251a9b93c57e3b24ae9a1305bbe7607eb696817f..5277b32f63e038a8420c9eb45e1c9d1f50106f55 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -3,26 +3,30 @@ set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_splitK_gemm_xdl) - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bfp16) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) - endif() + + add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) + + add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) + + add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8) + + add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16) + + add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) + + add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) + if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) + add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) endif() + set(target 1) endif() endforeach() diff --git a/example/35_splitK_gemm/run_splitK_gemm_example.inc b/example/35_splitK_gemm/run_splitK_gemm_example.inc index e9bd5c552dea937deb78859b94b731eec6ec03b0..e3690984abc08cff2e5a450f9f6b6ce13217b883 100644 --- a/example/35_splitK_gemm/run_splitK_gemm_example.inc +++ b/example/35_splitK_gemm/run_splitK_gemm_example.inc @@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con if(config.time_kernel) { - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 1}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp similarity index 100% rename from example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp rename to example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp index 74fb16e15b019e145cdf13549254be1b09fbd8ca..dc54bc30efeb9c5233e2af1d55381fc102d8e3f5 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp @@ -42,7 +42,7 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::KPadding; using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle // clang-format off diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b93639e6c11dc430e249645e39b944d4637e99c3 --- /dev/null +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F8; +using AccDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 16, 64, 8, 16, 16, 16, 1, 2, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Default, ADataType, BDataType>; + +// clang-format on + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..97a3f89e5ef3f90f5f431a20726a19d3de1d11d6 --- /dev/null +++ b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +#define DIRECT_LOAD 1 + +#if DIRECT_LOAD +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp" +#else +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" +#endif + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +#if DIRECT_LOAD + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle_LdsDirectLoad + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>; +// clang-format on + +#else + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#endif + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt b/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt index 36bb5720d53c1a738323f50d0e855d50d8996047..a9be3a7108fee6075dd7ae81f7536b072e655c0b 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt +++ b/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt @@ -1,3 +1 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp) -endif() +add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp) diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index 3821f8aaca74a3b028fc3917c10e5f279375c3fe..1ae179e950db0fea8e9f008d449cb8bfe23b54fd 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -1,15 +1,27 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_data) - add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp) - add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp) - - add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16) - add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16) - set(target 1) - endif() + if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_data) + + add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) + + add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) + + set(target 1) + endif() +endforeach() + +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_data) + + add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) + + set(target 1) + endif() endforeach() -endif() diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp index ca824b1075ced161760892fe6de53767b1079ccc..ebb1c606cb6c61de2ce41c846c06b1f1f200f1c4 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -10,7 +10,6 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp similarity index 97% rename from example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp rename to example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp index a3533bb4cc4b4874f0e7ae9f333347913d981ba0..4c28e25e019d41d2a3cb55e410286ab616dc4722 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" #include "common.hpp" using OutDataType = FP16; diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5baa521501704a2ea603f5b975cd6b392285cfb5 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" +#include "common.hpp" + +using OutDataType = FP16; +using WeiDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using DsDataType = ck::Tuple<>; +using InDataType = FP16; + +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using DsLayout = ck::Tuple<>; +using InLayout = ck::tensor_layout::convolution::GNHWC; + +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = PassThrough; + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + //| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < 2,OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>; +// clang-format on + +#include "run_grouped_conv_bwd_data_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp similarity index 97% rename from example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp rename to example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp index fb688b6f3f15383c2920a12e8902c2d4cf05a109..b1554412b1adca12cd49231a1c59cd6f34462ebf 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" #include "common.hpp" using OutDataType = FP16; diff --git a/example/39_permute/CMakeLists.txt b/example/39_permute/CMakeLists.txt index 5b43de9725aa3d535851926884fc6790233c6dbb..8b850c89a935ed194cad918a2bdbcf4b0c77d739 100644 --- a/example/39_permute/CMakeLists.txt +++ b/example/39_permute/CMakeLists.txt @@ -1,11 +1,10 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_custom_target(example_permute) +add_custom_target(example_permute) - add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) - add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp) - add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) +add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) +add_example_dependencies(example_permute example_permute_1xHxW_fp16) - add_dependencies(example_permute example_permute_1xHxW_fp16) - add_dependencies(example_permute example_permute_NxHxW_fp16) - add_dependencies(example_permute example_permute_HxWx4_fp16) -endif() +add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp) +add_example_dependencies(example_permute example_permute_NxHxW_fp16) + +add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) +add_example_dependencies(example_permute example_permute_HxWx4_fp16) diff --git a/example/40_conv2d_fwd_quantization/CMakeLists.txt b/example/40_conv2d_fwd_quantization/CMakeLists.txt index 55464957ac62e831e332e6b88b97e283e22866c8..2d804cafe9e7c5cc8cb2c6ca783fc9ffc8b1ab72 100644 --- a/example/40_conv2d_fwd_quantization/CMakeLists.txt +++ b/example/40_conv2d_fwd_quantization/CMakeLists.txt @@ -1,4 +1,3 @@ -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -11,7 +10,6 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() - if(DL_KERNELS) # Conv perlayer quantization add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) # Conv perchannel quantization @@ -24,5 +22,3 @@ endforeach() add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) # Conv + bias + tanh perchannel quantization add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) - endif() -endif() \ No newline at end of file diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp index 06c839e4e226e082e9754f12ce39cececbd36269..8c0049b0fa84dd836db25d2751c3f0d2aeca07a3 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" using InDataType = int8_t; using WeiDataType = int8_t; @@ -33,7 +33,7 @@ template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp index 7a9b42d39f971ab1c36476f4c0ead172cb6d2606..e18c123f7c3a18a5d6017bc5cf045c2060da18e8 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" using InDataType = int8_t; using WeiDataType = int8_t; @@ -31,7 +31,7 @@ template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp index 3495636297d5d83fb3ccc7d236e6135de9b715fd..53f810cc9ec078c7ee4da9ee3961fc7cac9f84a1 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" using InDataType = int8_t; using WeiDataType = int8_t; @@ -31,7 +31,7 @@ template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp index 2611337254212b326f6eef540c6a4eca5cd19261..9db6e201ddd14704fd7b4903d4510a2c91fafee4 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" using InDataType = int8_t; using WeiDataType = int8_t; @@ -26,7 +26,7 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio template using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, InLayout, WeiLayout, diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt index 085d90a9e9c5cc773c9935ba954853444e35740f..ae251e88d205a04968eab430b52c004a1136d853 100644 --- a/example/41_grouped_conv_conv_fwd/CMakeLists.txt +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -3,15 +3,9 @@ list(APPEND gpu_list2 gfx908 gfx90a) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) - endif() + add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) + add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) + add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) endif(USE_BITINT_EXTENSION_INT4) @@ -20,7 +14,5 @@ foreach(gpu IN LISTS GPU_TARGETS) endforeach() if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) - endif() + add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) endif() diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp index 80f6e9ae05712b1df4b05b1bb75fbff62d9ca70a..cf7b1ce3a813303841c47ae8b3084348563b41bb 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp @@ -1,9 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -#error Should compile this file with ck::int4_t support -#endif +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #include #include @@ -120,3 +118,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); #endif int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; } +#endif diff --git a/example/42_groupnorm/CMakeLists.txt b/example/42_groupnorm/CMakeLists.txt deleted file mode 100644 index bc2246a2bfd39ad933f42b5b5f47d135f6278f76..0000000000000000000000000000000000000000 --- a/example/42_groupnorm/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp) - add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp) - add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp) -endif() diff --git a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp b/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp deleted file mode 100644 index cc107b63dcd87735ba40ebbe7633c61993406166..0000000000000000000000000000000000000000 --- a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp +++ /dev/null @@ -1,56 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -constexpr int Rank = 5; -constexpr int NumReduceDim = 3; - -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; - -struct YElementOp -{ - template - __host__ __device__ void operator()(T& y, const T& x) const - { - static_assert(ck::is_same::value || ck::is_same::value || - ck::is_same::value, - "Data type is not supported by this operation!"); - - T a; - - ck::tensor_operation::element_wise::Sigmoid{}(a, x); - - y = x * a; - }; -}; - -using DeviceInstance = - ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector - -#include "run_groupnorm_example.inc" - -int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); } diff --git a/example/42_groupnorm/groupnorm_swish_fp16.cpp b/example/42_groupnorm/groupnorm_swish_fp16.cpp deleted file mode 100644 index 363f22ed4c0015c9a96aa98a8b6b5302978e3d25..0000000000000000000000000000000000000000 --- a/example/42_groupnorm/groupnorm_swish_fp16.cpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -constexpr int Rank = 5; -constexpr int NumReduceDim = 3; - -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; -using YElementOp = ck::tensor_operation::element_wise::Swish; - -using DeviceInstance = - ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector - -#include "run_groupnorm_example.inc" - -int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); } diff --git a/example/42_groupnorm_fwd/CMakeLists.txt b/example/42_groupnorm_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7d08baccd0d6938797f54aa83c4afabebdb26a3a --- /dev/null +++ b/example/42_groupnorm_fwd/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_groupnorm_fwd_sigmoid_mul_fp16 groupnorm_fwd_sigmoid_mul_fp16.cpp) +add_example_executable(example_groupnorm_fwd_splitk_fp16 groupnorm_fwd_splitk_fp16.cpp) +add_example_executable(example_groupnorm_fwd_swish_fp16 groupnorm_fwd_swish_fp16.cpp) diff --git a/example/42_groupnorm/common.hpp b/example/42_groupnorm_fwd/common.hpp similarity index 95% rename from example/42_groupnorm/common.hpp rename to example/42_groupnorm_fwd/common.hpp index c8f91eb53b5b943bfbdd264b842cf7762a176b82..038a8c0f1fea900754fa1ed27c6b5210e92072e9 100644 --- a/example/42_groupnorm/common.hpp +++ b/example/42_groupnorm_fwd/common.hpp @@ -11,8 +11,8 @@ #include "ck/ck.hpp" #include "ck/utility/reduction_enums.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/library/utility/fill.hpp" diff --git a/example/42_groupnorm_fwd/groupnorm_fwd_sigmoid_mul_fp16.cpp b/example/42_groupnorm_fwd/groupnorm_fwd_sigmoid_mul_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..15c02b8213c9f70bf919ec4c6eb083d7cf1573bf --- /dev/null +++ b/example/42_groupnorm_fwd/groupnorm_fwd_sigmoid_mul_fp16.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; + +#define SAVE_MEAN_INV_STD + +struct YElementOp +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + static_assert(ck::is_same::value || ck::is_same::value || + ck::is_same::value, + "Data type is not supported by this operation!"); + + static_assert(ck::is_same::value || ck::is_same::value || + ck::is_same::value, + "Data type is not supported by this operation!"); + + X a; + + ck::tensor_operation::element_wise::Sigmoid{}(a, x); + + y = ck::type_convert(x * a); + }; +}; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationFwdImpl; // SaveMeanInvStdScalarPerVector + +#include "run_groupnorm_fwd_example.inc" + +int main(int argc, char* argv[]) { run_groupnorm_fwd_example(argc, argv); } diff --git a/example/42_groupnorm_fwd/groupnorm_fwd_splitk_fp16.cpp b/example/42_groupnorm_fwd/groupnorm_fwd_splitk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..37fdf1b253d677c7312c09192bbbe1c521768beb --- /dev/null +++ b/example/42_groupnorm_fwd/groupnorm_fwd_splitk_fp16.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using YElementOp = ck::tensor_operation::element_wise::Swish; + +#define SAVE_MEAN_INV_STD + +using DeviceInstance = ck::tensor_operation::device::DeviceNormalizationFwdSplitKImpl< + XDataType, + GammaDataType, + BetaDataType, + ComputeDataType, + YDataType, + SaveMeanInvStdDataType, + YElementOp, + Rank, + NumReduceDim, + 256, // BlockSize + 1, // ClusterM + 256, // ClusterK + 1, // SliceM + 16, // SliceK + 1, // SrcVecDim (0=M, 1=K) + 2, // SrcScalarPerVector + 1, // GammaVecDim (0=M, 1=K) + 2, // GammaScalarPerVector + 1, // BetaVecDim (0=M, 1=K) + 2, // BetaScalarPerVector + 2, // YScalarPerVector + 1>; // SaveMeanInvStdScalarPerVector + +#include "run_groupnorm_fwd_example.inc" + +int main(int argc, char* argv[]) { run_groupnorm_fwd_example(argc, argv); } diff --git a/example/42_groupnorm_fwd/groupnorm_fwd_swish_fp16.cpp b/example/42_groupnorm_fwd/groupnorm_fwd_swish_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6d17264cef8f1c42c02ed94b76bef22db7059b53 --- /dev/null +++ b/example/42_groupnorm_fwd/groupnorm_fwd_swish_fp16.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using YElementOp = ck::tensor_operation::element_wise::Swish; + +#define SAVE_MEAN_INV_STD + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationFwdImpl; // SaveMeanInvStdScalarPerVector + +#include "run_groupnorm_fwd_example.inc" + +int main(int argc, char* argv[]) { run_groupnorm_fwd_example(argc, argv); } diff --git a/example/42_groupnorm/run_groupnorm_example.inc b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc similarity index 57% rename from example/42_groupnorm/run_groupnorm_example.inc rename to example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc index 16065c8d46a42433416e21e9607dc3fcb708d817..853ff791a6e8d0d07cf1584243e267f512291530 100644 --- a/example/42_groupnorm/run_groupnorm_example.inc +++ b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc @@ -3,7 +3,7 @@ #pragma once -int run_groupnorm_example(int argc, char* argv[]) +int run_groupnorm_fwd_example(int argc, char* argv[]) { ck::index_t N = 32; ck::index_t H = 16; @@ -34,6 +34,8 @@ int run_groupnorm_example(int argc, char* argv[]) Tensor y({N, H, W, G, C}); Tensor gamma({G, C}); Tensor beta({G, C}); + Tensor save_mean({N, G}); + Tensor save_inv_std({N, G}); ck::utils::FillUniformDistribution{0.f, 1.f}(x); ck::utils::FillUniformDistribution{0.f, 1.f}(gamma); @@ -43,6 +45,11 @@ int run_groupnorm_example(int argc, char* argv[]) DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); +#ifdef SAVE_MEAN_INV_STD + DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize()); + DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) * + save_inv_std.mDesc.GetElementSpaceSize()); +#endif x_dev.ToDevice(x.mData.data()); gamma_dev.ToDevice(gamma.mData.data()); @@ -57,14 +64,23 @@ int run_groupnorm_example(int argc, char* argv[]) {0, 0, 0, C, 1}, {0, 0, 0, C, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, {1, 2, 4}, // reduction dimension: [H, W, C] 1e-6, x_dev.GetDeviceBuffer(), gamma_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.GetDeviceBuffer(), + save_inv_std_dev.GetDeviceBuffer(), +#else nullptr, nullptr, +#endif y_element_op); if(!device_instance.IsSupportedArgument(argument_ptr.get())) @@ -92,21 +108,40 @@ int run_groupnorm_example(int argc, char* argv[]) bool pass = true; { Tensor host_y({N, H, W, G, C}); - using ReferenceInstance = ck::tensor_operation::host::ReferenceGroupnorm; + Tensor host_save_mean(HostTensorDescriptor{N, G}); + Tensor host_save_inv_std(HostTensorDescriptor{N, G}); + using ReferenceInstance = + ck::tensor_operation::host::ReferenceGroupnorm; ReferenceInstance ref; - auto ref_argument = - ref.MakeArgument(x, gamma, beta, host_y, y_element_op, {N, H, W, G, C}, 1e-6); - auto ref_invoker = ref.MakeInvoker(); + auto ref_argument = ref.MakeArgument(x, + gamma, + beta, + host_y, + host_save_mean, + host_save_inv_std, + y_element_op, + {N, H, W, G, C}, + 1e-6); + auto ref_invoker = ref.MakeInvoker(); ref_invoker.Run(ref_argument); y_dev.FromDevice(y.mData.data()); pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.FromDevice(save_mean.mData.data()); + save_inv_std_dev.FromDevice(save_inv_std.mData.data()); + pass &= ck::utils::check_err( + save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3); + pass &= ck::utils::check_err( + save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3); +#endif } return (pass ? 0 : 1); diff --git a/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt b/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt index 7e070f5357999b73ce647489c37613a4e52329f8..c29f18f1627ef20fe69cb3751d3a7766ffbd236c 100644 --- a/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt +++ b/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt @@ -1,6 +1,2 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp) -endif() +add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp) +add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp) diff --git a/example/44_elementwise_permute/CMakeLists.txt b/example/44_elementwise_permute/CMakeLists.txt index 877a82031598ba63c3c331b0bf7df34a2c07d5a7..a963399dc778035b3521c769596257654f80ea91 100644 --- a/example/44_elementwise_permute/CMakeLists.txt +++ b/example/44_elementwise_permute/CMakeLists.txt @@ -1,4 +1,10 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) - add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp) +add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) +add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp) +add_example_executable(example_elementwise_permute_4D_fp32_row elementwise_permute_4D_fp32_row.cpp) +add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permute_4D_fp16_row.cpp) +add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permute_4D_fp32_col.cpp) +add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp) +add_example_executable(example_elementwise_permute elementwise_permute.cpp) +if((NOT GPU_TARGETS MATCHES "gfx940") AND (NOT GPU_TARGETS MATCHES "gfx941") AND (NOT GPU_TARGETS MATCHES "gfx942")) + add_example_executable(example_elementwise_permute_3d elementwise_permute_3d.cpp) endif() diff --git a/example/44_elementwise_permute/elementwise_permute.cpp b/example/44_elementwise_permute/elementwise_permute.cpp new file mode 100644 index 0000000000000000000000000000000000000000..24e161c6d3f8518bc2063c6113fc4c67784a2a88 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_permute.cpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using DeviceElementwisePermuteInstance = + ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + PassThrough, // ElementwiseOp + 5, // NumDim + 8, // MPerThread + ck::Sequence<1>, // InScalarPerVectorSeq + ck::Sequence<1>>; // OutScalarPerVectorSeq + +template +void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor) +{ + for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n) + for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c) + for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d) + for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h) + for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w) + { + auto a_val = A_ncdhw(n, c, d, h, w); + functor(B_ndhwc(n, d, h, w, c), a_val); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector ncdhw = {16, 8, 8, 8, 8}; + std::vector ndhwc = {16, 8, 8, 8, 8}; + Tensor a(ncdhw); + Tensor b(ndhwc); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + std::array ab_lengths; + /**std::array a_strides = { + static_cast(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]), + static_cast(ncdhw[2] * ncdhw[3] * ncdhw[4]), + static_cast(ncdhw[3] * ncdhw[4]), + static_cast(ncdhw[4]), + 1}; + std::array b_strides = { + static_cast(ndhwc[1] * ndhwc[2] * ndhwc[3] * ndhwc[4]), + static_cast(ndhwc[2] * ndhwc[3] * ndhwc[4]), + 1, + static_cast(ndhwc[3] * ndhwc[4]), + static_cast(ndhwc[4])};**/ + + std::array a_strides = { + static_cast(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]), + static_cast(ncdhw[3] * ncdhw[4]), + static_cast(ncdhw[4]), + 1, + static_cast(ncdhw[2] * ncdhw[3] * ncdhw[4])}; + + std::array b_strides = { + static_cast(ndhwc[1] * ndhwc[2] * ndhwc[3] * ndhwc[4]), + static_cast(ndhwc[2] * ndhwc[3] * ndhwc[4]), + static_cast(ndhwc[3] * ndhwc[4]), + static_cast(ndhwc[4]), + 1}; + ck::ranges::copy(ncdhw, ab_lengths.begin()); + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A (ncdhw): " << a.mDesc << std::endl; + std::cout << "B (ndhwc): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(2) * ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]; + + std::size_t num_btype = + sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) + + sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + b_device_buf.FromDevice(b.mData.data()); + Tensor host_b(ndhwc); + host_elementwise4D(host_b, a, PassThrough{}); + + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_permute_3d.cpp b/example/44_elementwise_permute/elementwise_permute_3d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f3aca57c350e4d6d4a073fed9fa2512bd12b20bc --- /dev/null +++ b/example/44_elementwise_permute/elementwise_permute_3d.cpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F32; +using BDataType = F32; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using DeviceElementwisePermuteInstance = + ck::tensor_operation::device::DeviceElementwise3dImpl, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + PassThrough, // ElementwiseOp + 2, // NumDim_m, {N, C} + 2, // NumDim_n, {H, W} + 1, // NumDim_k, {D} + 4, // MPerThread + 4, // NPerThread + 4, // KPerThread + ck::Sequence<4>, // InScalarPerVectorSeq + ck::Sequence<4>>; // OutScalarPerVectorSeq + +template +void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor) +{ + for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n) + for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c) + for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d) + for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h) + for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w) + { + auto a_val = A_ncdhw(n, c, d, h, w); + functor(B_ndhwc(n, d, h, w, c), a_val); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + const int N = 4; + const int C = 16; + const int H = 32; + const int W = 5; + const int D = 16; + + std::vector ncdhw = {N, C, D, H, W}; + std::vector ndhwc = {N, D, H, W, C}; + Tensor a(ncdhw); + Tensor b(ndhwc); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + std::array ab_lengths{N, C, H, W, D}; + std::array a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W + std::array b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A (ncdhw): " << a.mDesc << std::endl; + std::cout << "B (ndhwc): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(2) * ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]; + + std::size_t num_btype = + sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) + + sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + b_device_buf.FromDevice(b.mData.data()); + Tensor host_b(ndhwc); + host_elementwise4D(host_b, a, PassThrough{}); + + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index 2ceda86839abc37bd71fc1179d12be51d42e9a20..8e9bc64ab6fade4e9353264a37de42d3d5fac4bb 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + #include #include @@ -19,13 +22,13 @@ using BDataType = F16; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using DeviceElementwisePermuteInstance = - ck::tensor_operation::device::DeviceElementwiseImpl, - ck::Tuple, - PassThrough, - 4, - 8, - ck::Sequence<8>, - ck::Sequence<1>>; + ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + PassThrough, // Elementwise op + 4, // NumDim + 8, // MPerThread + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<1>>; // OutScalarPerVectorSeq template void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) @@ -99,7 +102,6 @@ int main() std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; - bool pass = true; if(do_verification) diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp index 6b94a5d46f5ad20474900365ab6397ef48d87769..30231a3758eedfbc0f0e764543cb2bd88ea35b69 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + #include #include @@ -17,15 +20,15 @@ using BDataType = F16; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using DeviceElementwisePermuteInstance = - ck::tensor_operation::device::DeviceElementwise2dImpl, - ck::Tuple, - PassThrough, - 3, // NumDim_M - 1, // NumDim_N - 8, - 8, - ck::Sequence<8>, - ck::Sequence<8>>; + ck::tensor_operation::device::DeviceElementwise2dImpl, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + PassThrough, // Elementwise op + 3, // NumDim_M + 1, // NumDim_N + 1, // MPerThread + 1, // NPerThread + ck::Sequence<1>, // InScalarPerVectorSeq + ck::Sequence<1>>; // OutScalarPerVectorSeq template void host_elementwise4D(HostTensorB& B_nhwc, @@ -53,12 +56,6 @@ int main() const int H = 32; const int W = 1024; - /**const int N = 120; - const int H = 32; - const int W = 64; - - const int C = 128;**/ - std::vector nchw = {N, C, H, W}; std::vector nhwc = {N, H, W, C}; @@ -71,7 +68,6 @@ int main() DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); a_device_buf.ToDevice(a.mData.data()); - // LogRangeAsType(std::cout << "Tensor a : ", a.mData, ",") << std::endl; std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; @@ -115,13 +111,10 @@ int main() if(do_verification) { b_device_buf.FromDevice(b.mData.data()); - // LogRangeAsType(std::cout << "Tensor b : ", b.mData, ",") << std::endl; Tensor host_b(nhwc); host_elementwise4D, Tensor, PassThrough>( host_b, a, nchw, PassThrough{}); - - // LogRangeAsType(std::cout << "Host b : ", host_b.mData, ",") << std::endl; pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9d5fdc0cc70136369e3f396949e2539b1c3b0c8b --- /dev/null +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; +using Scale = ck::tensor_operation::element_wise::Scale; +using DeviceElementwisePermuteInstance = + ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + PassThrough, // ElementwiseOp + UnaryOp, // UnaryOp + Scale, // Scalar + 4, // NumDim + 8, // MPerThread + ck::Sequence<1>, // InScalarPerVectorSeq + ck::Sequence<1>>; // OutScalarPerVectorSeq + +template +void host_elementwise4D(HostTensorB& B_nhwc, + const HostTensorA& A_nchw, + FunctorA functor_a, + FunctorB functor_b, + float scale) +{ + std::size_t N = A_nchw.mDesc.GetLengths()[0]; + std::size_t C = A_nchw.mDesc.GetLengths()[1]; + std::size_t H = A_nchw.mDesc.GetLengths()[2]; + std::size_t W = A_nchw.mDesc.GetLengths()[3]; + for(std::size_t w = 0; w < W; ++w) + for(std::size_t h = 0; h < H; ++h) + for(std::size_t c = 0; c < C; ++c) + for(std::size_t n = 0; n < N; ++n) + { + ADataType tmp_val; + auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; + functor_b(tmp_val, a_val); + functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], + scale * tmp_val); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 8, 32, 64}; + std::vector nhwc = {16, 32, 64, 8}; + Tensor a(nchw); + Tensor b(nhwc); + float scale = 1.f; + auto i = 0; + std::mt19937 gen(11939); + std::uniform_int_distribution dis(0, 1); + for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w) + for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h) + for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c) + for(std::size_t n = 0; n < a.mDesc.GetLengths()[0]; ++n) + { + a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) + + (h * nchw[3]) + w] = i; + i = dis(gen); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + std::array ab_lengths; + + std::array a_strides = {1, + static_cast(nchw[0]), + static_cast(nchw[0] * nchw[1]), + static_cast(nchw[0] * nchw[1] * nchw[2])}; + + std::array b_strides = {1, + static_cast(nhwc[0] * nhwc[1] * nhwc[2]), + static_cast(nhwc[0]), + static_cast(nhwc[0] * nhwc[1])}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + PassThrough{}, + UnaryOp{}, + Scale{scale}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A (nchw): " << a.mDesc << std::endl; + std::cout << "B (nhwc): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + b_device_buf.FromDevice(b.mData.data()); + Tensor host_b(nhwc); + host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale); + + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7d215cef246bd8505aa61599dddf5a0988845a1e --- /dev/null +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; +using Scale = ck::tensor_operation::element_wise::Scale; +using DeviceElementwisePermuteInstance = + ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + PassThrough, // ElementwiseOp + UnaryOp, // UnaryOp + Scale, // Scalar + 4, // NumDim + 8, // MPerThread + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<1>>; // OutScalarPerVectorSeq + +template +void host_elementwise4D(HostTensorB& B_nhwc, + const HostTensorA& A_nchw, + FunctorA functor_a, + FunctorB functor_b, + float scale) +{ + for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) + for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) + for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) + for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) + { + ADataType tmp_val; + auto a_val = A_nchw(n, c, h, w); + functor_b(tmp_val, a_val); + functor_a(B_nhwc(n, h, w, c), scale * tmp_val); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 128, 32, 64}; + std::vector nhwc = {16, 32, 64, 128}; + Tensor a(nchw); + Tensor b(nhwc); + float scale = 2.f; + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + std::array ab_lengths; + std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), + static_cast(nchw[2] * nchw[3]), + static_cast(nchw[3]), + 1}; + std::array b_strides = {static_cast(nhwc[1] * nhwc[2] * nhwc[3]), + 1, + static_cast(nhwc[2] * nhwc[3]), + static_cast(nhwc[3])}; + + ck::ranges::copy(nchw, ab_lengths.begin()); + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + PassThrough{}, + UnaryOp{}, + Scale{scale}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A (nchw): " << a.mDesc << std::endl; + std::cout << "B (nhwc): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + b_device_buf.FromDevice(b.mData.data()); + Tensor host_b(nhwc); + host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale); + + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp new file mode 100644 index 0000000000000000000000000000000000000000..69e411c59ab2c3bbc959f816380290bd4df4cf79 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F32; +using BDataType = F32; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; +using Scale = ck::tensor_operation::element_wise::Scale; +using DeviceElementwisePermuteInstance = + ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + PassThrough, // ElementwiseOp + UnaryOp, // UnaryOp + Scale, // Scalar + 4, // NumDim + 1, // MPerThread + ck::Sequence<1>, // InScalarPerVectorSeq + ck::Sequence<1>>; // OutScalarPerVectorSeq + +template +void host_elementwise4D(HostTensorB& B_nhwc, + const HostTensorA& A_nchw, + FunctorA functor_a, + FunctorB functor_b, + float scale) +{ + std::size_t N = A_nchw.mDesc.GetLengths()[0]; + std::size_t C = A_nchw.mDesc.GetLengths()[1]; + std::size_t H = A_nchw.mDesc.GetLengths()[2]; + std::size_t W = A_nchw.mDesc.GetLengths()[3]; + for(std::size_t w = 0; w < W; ++w) + for(std::size_t h = 0; h < H; ++h) + for(std::size_t c = 0; c < C; ++c) + for(std::size_t n = 0; n < N; ++n) + { + ADataType tmp_val; + auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; + functor_b(tmp_val, a_val); + functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], + scale * tmp_val); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {5, 4, 2, 3}; + std::vector nhwc = {5, 2, 3, 4}; + Tensor a(nchw); + Tensor b(nhwc); + + float scale = 1.f; + auto i = 0; + std::mt19937 gen(11939); + std::uniform_int_distribution dis(0, 1); + for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w) + for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h) + for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c) + for(std::size_t n = 0; n < a.mDesc.GetLengths()[0]; ++n) + { + a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) + + (h * nchw[3]) + w] = i; + i = dis(gen); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + std::array ab_lengths; + + std::array a_strides = {1, + static_cast(nchw[0]), + static_cast(nchw[0] * nchw[1]), + static_cast(nchw[0] * nchw[1] * nchw[2])}; + + std::array b_strides = {1, + static_cast(nhwc[0] * nhwc[1] * nhwc[2]), + static_cast(nhwc[0]), + static_cast(nhwc[0] * nhwc[1])}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + PassThrough{}, + UnaryOp{}, + Scale{scale}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A (nchw): " << a.mDesc << std::endl; + std::cout << "B (nhwc): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + b_device_buf.FromDevice(b.mData.data()); + Tensor host_b(nhwc); + host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale); + + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp new file mode 100644 index 0000000000000000000000000000000000000000..69f40fe1650bca6bf5d0ca741c61b2484a3ed927 --- /dev/null +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F32; +using BDataType = F32; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; +using Scale = ck::tensor_operation::element_wise::Scale; +using DeviceElementwisePermuteInstance = + ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + PassThrough, // ElementwiseOp + UnaryOp, // UnaryOp + Scale, // Scalar + 4, // NumDim + 8, // MPerThread + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<1>>; // OutScalarPerVectorSeq + +template +void host_elementwise4D(HostTensorB& B_nhwc, + const HostTensorA& A_nchw, + FunctorA functor_a, + FunctorB functor_b, + float scale) +{ + for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) + for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) + for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) + for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) + { + ADataType tmp_val; + auto a_val = A_nchw(n, c, h, w); + functor_b(tmp_val, a_val); + functor_a(B_nhwc(n, h, w, c), scale * tmp_val); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 128, 32, 64}; + std::vector nhwc = {16, 32, 64, 128}; + Tensor a(nchw); + Tensor b(nhwc); + float scale = 2.f; + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + std::array ab_lengths; + std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), + static_cast(nchw[2] * nchw[3]), + static_cast(nchw[3]), + 1}; + std::array b_strides = {static_cast(nhwc[1] * nhwc[2] * nhwc[3]), + 1, + static_cast(nhwc[2] * nhwc[3]), + static_cast(nhwc[3])}; + + ck::ranges::copy(nchw, ab_lengths.begin()); + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + PassThrough{}, + UnaryOp{}, + Scale{scale}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A (nchw): " << a.mDesc << std::endl; + std::cout << "B (nhwc): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + b_device_buf.FromDevice(b.mData.data()); + Tensor host_b(nhwc); + host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale); + + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp index 76361f87a5b58071b149cab505f7c0485abfadab..c02d540983ed17f14f3376ab158308147337fc11 100644 --- a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp +++ b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp @@ -167,20 +167,31 @@ int main() XElementwiseOperation>(x, a, b, mn, XElementwiseOperation{}); Tensor host_y(f_host_tensor_descriptor2d(M, N, Stride)); + Tensor host_save_mean({M}); + Tensor host_save_inv_std({M}); using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm; ReferenceInstance ref; - auto ref_argument = - ref.MakeArgument(x, gamma, beta, host_y, YElementwiseOperation{}, {M, N}, {1}, 1e-4); - auto ref_invoker = ref.MakeInvoker(); + auto ref_argument = ref.MakeArgument(x, + gamma, + beta, + host_y, + host_save_mean, + host_save_inv_std, + YElementwiseOperation{}, + {M, N}, + {1}, + 1e-4); + auto ref_invoker = ref.MakeInvoker(); ref_invoker.Run(ref_argument); y_dev.FromDevice(y.mData.data()); diff --git a/example/46_gemm_add_multiply/CMakeLists.txt b/example/46_gemm_add_multiply/CMakeLists.txt index cf7d81f895b82244a0ae7ac127e9df4c9a03324f..bfe057e8da677f30844a742b032522a2a61a947c 100644 --- a/example/46_gemm_add_multiply/CMakeLists.txt +++ b/example/46_gemm_add_multiply/CMakeLists.txt @@ -1,6 +1,2 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - if(DL_KERNELS) - add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp) - endif() - add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp) -endif() +add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp) +add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp) diff --git a/example/48_pool3d_fwd/CMakeLists.txt b/example/48_pool3d_fwd/CMakeLists.txt index f821f2c97b0e7f611d500c07ff1297f70ddd6251..492cb4d37ec8e5e851106f6611e09e5ff9d54d7a 100644 --- a/example/48_pool3d_fwd/CMakeLists.txt +++ b/example/48_pool3d_fwd/CMakeLists.txt @@ -1,3 +1 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp) -endif() +add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp) diff --git a/example/48_pool3d_fwd/pool3d_fwd_common.hpp b/example/48_pool3d_fwd/pool3d_fwd_common.hpp index 39032fa12378aa1ac5a00783384f130246b5dd93..788f38ec52d06acf2f4cf891cf29191d7948a39d 100644 --- a/example/48_pool3d_fwd/pool3d_fwd_common.hpp +++ b/example/48_pool3d_fwd/pool3d_fwd_common.hpp @@ -32,6 +32,8 @@ std::vector f_tensor_strides_ncdhw(ck::index_t N_, return {C_ * D * H * W, D * H * W, H * W, W, 1_uz}; else if constexpr(ck::is_same::value) return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}; + throw std::runtime_error("Pool3d_fwd: problem with layout. "); + return {0, 0, 0, 0, 0}; }; template @@ -53,6 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, return HostTensorDescriptor({N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); } + throw std::runtime_error("Pool3d_fwd: problem with layout. "); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); }; template ; // InSrcOutDstVectorSize using DeviceMaxPoolBwdInstance = ck::tensor_operation::device:: - DeviceIndexPoolBwdImpl; + DeviceMaxPoolBwdImpl; const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; const ck::index_t Xs = (X - 1) * window_dilation_w + 1; @@ -155,7 +155,8 @@ bool maxpool_bwd_test(bool do_verification, dout_n_c_ho_wo.mDesc.GetElementSpaceSize(), din_n_c_hi_wi_device.mDesc.GetElementSpaceSize(), window_spatial_lengths, - window_strides); + window_strides, + window_dilations); if(!pool_bwd.IsSupportedArgument(pool_bwd_argument_ptr.get())) { diff --git a/example/50_put_element/CMakeLists.txt b/example/50_put_element/CMakeLists.txt index eca410008fde33b6236c2c14c10650a74640c53d..1b0020ebcf65a87ad39b004c149385cdd7de89ae 100644 --- a/example/50_put_element/CMakeLists.txt +++ b/example/50_put_element/CMakeLists.txt @@ -1,3 +1 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_put_element_fp16 put_element_fp16.cpp) -endif() +add_example_executable(example_put_element_fp16 put_element_fp16.cpp) diff --git a/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp b/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp index 394f046b1e9d2136e42001f366d7423f1f5e8600..032828f7bcd47889502a619cb273c302afdc9bfc 100644 --- a/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp +++ b/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp @@ -26,6 +26,8 @@ std::vector f_tensor_strides_ncdhw(ck::index_t N_, return {C_ * D * H * W, D * H * W, H * W, W, 1_uz}; else if constexpr(ck::is_same::value) return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}; + throw std::runtime_error("Avgpool3d_bwd: problem with layout. "); + return {0, 0, 0, 0, 0}; }; template @@ -47,6 +49,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, return HostTensorDescriptor({N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); } + throw std::runtime_error("Avgpool3d_bwd: problem with layout. "); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); }; template , 1>; +// clang-format on + +bool RunColumnToImage(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params) +{ + const auto G = conv_params.G_; + const auto N = conv_params.N_; + const auto C = conv_params.C_; + + const ck::index_t NDoHoWo = + N * ck::accumulate_n( + conv_params.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const ck::index_t CZYX = + C * ck::accumulate_n( + conv_params.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + const auto in_desc = HostTensorDescriptor({G, NDoHoWo, CZYX}); + const auto out_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_params); + + std::array input_spatial_lengths{}; + std::array filter_spatial_lengths{}; + std::array output_spatial_lengths{}; + std::array image_g_n_c_wis_strides{}; + std::array gemm_g_m_k_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; + + copy(conv_params.input_spatial_lengths_, input_spatial_lengths); + copy(conv_params.filter_spatial_lengths_, filter_spatial_lengths); + copy(conv_params.output_spatial_lengths_, output_spatial_lengths); + copy(in_desc.GetStrides(), gemm_g_m_k_strides); + copy(out_desc.GetStrides(), image_g_n_c_wis_strides); + copy(conv_params.conv_filter_strides_, conv_filter_strides); + copy(conv_params.conv_filter_dilations_, conv_filter_dilations); + copy(conv_params.input_left_pads_, input_left_pads); + copy(conv_params.input_right_pads_, input_right_pads); + + Tensor in(in_desc); + Tensor out_device(out_desc); + Tensor out_host(out_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "out: " << out_device.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: in.GenerateTensorValue(GeneratorTensor_2{1, 2}); break; + default: in.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + + // reset input to zero + out_device_buf.SetZero(); + + static_assert(std::is_default_constructible_v); + + // do conv + auto col2img = DeviceColToImgInstance{}; + auto invoker = col2img.MakeInvoker(); + auto argument = col2img.MakeArgument(in_device_buf.GetDeviceBuffer(), + out_device_buf.GetDeviceBuffer(), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + if(!col2img.IsSupportedArgument(argument)) + { + std::cerr << "wrong! device_col2img with the specified compilation parameters does " + "not support this col2img problem" + << std::endl; + + return false; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + std::size_t num_btype = G * NDoHoWo * CZYX * (sizeof(OutDataType) + sizeof(InDataType)); + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + if(config.do_verification) + { + auto ref_column_to_image = ck::tensor_operation::host:: + ReferenceColumnToImage(); + + auto ref_invoker = ref_column_to_image.MakeInvoker(); + + auto ref_argument = ref_column_to_image.MakeArgument(in, + out_host, + conv_params.filter_spatial_lengths_, + conv_params.conv_filter_strides_, + conv_params.conv_filter_dilations_, + conv_params.input_left_pads_, + conv_params.input_right_pads_); + + if(!ref_column_to_image.IsSupportedArgument(&ref_argument)) + { + std::cerr << "wrong! ref_col2img with the specified compilation parameters does " + "not support this col2img problem" + << std::endl; + return false; + } + + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(out_device.mData.data()); + return ck::utils::check_err(out_device.mData, out_host.mData); + } + + return true; +} + +int RunColumnToImageExample(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_params = DefaultConvParams; + + if(!parse_cmd_args(argc, argv, config, conv_params)) + { + return EXIT_FAILURE; + } + + if(conv_params.num_dim_spatial_ != NDimSpatial) + { + std::cerr << "unsupported # of spatial dimensions" << std::endl; + return EXIT_FAILURE; + } + + return !RunColumnToImage(config, conv_params); +} + +int main(int argc, char* argv[]) { return RunColumnToImageExample(argc, argv); } diff --git a/example/52_im2col_col2im/common.hpp b/example/52_im2col_col2im/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..61d30c4cb44abfcdbebffd88aebb27d9783f88c5 --- /dev/null +++ b/example/52_im2col_col2im/common.hpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp" + +template +using S = ck::Sequence; + +static inline constexpr ck::index_t NDimSpatial = 2; + +using FP32 = float; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +#define DefaultConvParams \ + ck::utils::conv::ConvParam \ + { \ + NDimSpatial, 1, 32, 1, 1, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \ + } + +inline void print_help_msg() +{ + std::cerr << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +inline bool parse_cmd_args(int argc, + char* argv[], + ExecutionConfig& config, + ck::utils::conv::ConvParam& conv_params) +{ + constexpr int num_execution_config_args = + 3; // arguments for do_verification, init_method, time_kernel + constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ + + constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; + constexpr int threshold_to_catch_all_args = + threshold_to_catch_partial_args + num_conv_param_leading_args; + + if(argc == 1) + { + // use default + config = ExecutionConfig{}; + } + // catch only ExecutionConfig arguments + else if(argc == threshold_to_catch_partial_args) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + // catch both ExecutionConfig & ConvParam arguments + else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + conv_params = ck::utils::conv::parse_conv_param( + num_dim_spatial, threshold_to_catch_partial_args, argv); + } + else + { + print_help_msg(); + return false; + } + + return true; +} diff --git a/example/52_im2col_col2im/image_to_column_f32.cpp b/example/52_im2col_col2im/image_to_column_f32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..140bdf8062d1ad91d02f84b4056d4b446011e632 --- /dev/null +++ b/example/52_im2col_col2im/image_to_column_f32.cpp @@ -0,0 +1,168 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using InDataType = FP32; +using OutDataType = FP32; + +using ImLayout = ck::tensor_layout::convolution::GNHWC; +using ImageToColumnOp = ck::conv_tensor_rearrange_op::ImageToColumn; + +// clang-format off +using DeviceImgToColInstance = ck::tensor_operation::device::DeviceImageToColumnImpl + //#####################| Num| ImLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| + //#####################| Dim| | | | Size| Block| Block| Cluster| Per| + //#####################| Spatial| | | | | | | Lengths| Vector| + //#####################| | | | | | | | | | + < NDimSpatial, ImLayout, InDataType, OutDataType, 256, 128, 128, S<16, 16>, 1>; +// clang-format on + +bool RunImageToColumn(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params) +{ + const auto G = conv_params.G_; + const auto N = conv_params.N_; + const auto C = conv_params.C_; + + const ck::index_t NDoHoWo = + N * ck::accumulate_n( + conv_params.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const ck::index_t CZYX = + C * ck::accumulate_n( + conv_params.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + const auto in_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_params); + const auto out_desc = HostTensorDescriptor({G, NDoHoWo, CZYX}); + + std::array input_spatial_lengths{}; + std::array filter_spatial_lengths{}; + std::array output_spatial_lengths{}; + std::array image_g_n_c_wis_strides{}; + std::array gemm_g_m_k_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; + + copy(conv_params.input_spatial_lengths_, input_spatial_lengths); + copy(conv_params.filter_spatial_lengths_, filter_spatial_lengths); + copy(conv_params.output_spatial_lengths_, output_spatial_lengths); + copy(in_desc.GetStrides(), image_g_n_c_wis_strides); + copy(out_desc.GetStrides(), gemm_g_m_k_strides); + copy(conv_params.conv_filter_strides_, conv_filter_strides); + copy(conv_params.conv_filter_dilations_, conv_filter_dilations); + copy(conv_params.input_left_pads_, input_left_pads); + copy(conv_params.input_right_pads_, input_right_pads); + + Tensor in(in_desc); + Tensor out_device(out_desc); + Tensor out_host(out_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "out: " << out_device.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; + default: in.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + + // reset input to zero + out_device_buf.SetZero(); + + static_assert(std::is_default_constructible_v); + + // do conv + auto img2col = DeviceImgToColInstance{}; + auto invoker = img2col.MakeInvoker(); + auto argument = img2col.MakeArgument(in_device_buf.GetDeviceBuffer(), + out_device_buf.GetDeviceBuffer(), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + if(!img2col.IsSupportedArgument(argument)) + { + std::cerr << "wrong! device_img2col with the specified compilation parameters does " + "not support this img2col problem" + << std::endl; + + return false; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + std::size_t num_btype = G * NDoHoWo * CZYX * (sizeof(OutDataType) + sizeof(InDataType)); + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + if(config.do_verification) + { + auto ref_image_to_column = ck::tensor_operation::host:: + ReferenceImageToColumn(); + + auto ref_invoker = ref_image_to_column.MakeInvoker(); + + auto ref_argument = ref_image_to_column.MakeArgument(in, + out_host, + conv_params.filter_spatial_lengths_, + conv_params.conv_filter_strides_, + conv_params.conv_filter_dilations_, + conv_params.input_left_pads_, + conv_params.input_right_pads_); + + if(!ref_image_to_column.IsSupportedArgument(&ref_argument)) + { + std::cerr << "wrong! ref_img2col with the specified compilation parameters does " + "not support this img2col problem" + << std::endl; + return false; + } + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device.mData, out_host.mData); + } + + return true; +} + +int RunImageToColumnExample(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_params = DefaultConvParams; + + if(!parse_cmd_args(argc, argv, config, conv_params)) + { + return EXIT_FAILURE; + } + + if(conv_params.num_dim_spatial_ != NDimSpatial) + { + std::cerr << "unsupported # of spatial dimensions" << std::endl; + return EXIT_FAILURE; + } + + return !RunImageToColumn(config, conv_params); +} + +int main(int argc, char* argv[]) { return RunImageToColumnExample(argc, argv); } diff --git a/example/53_layernorm2d_bwd/CMakeLists.txt b/example/53_layernorm2d_bwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a58b1109f7c409529486ee0ab62cacea7696843f --- /dev/null +++ b/example/53_layernorm2d_bwd/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_layernorm2d_bwd_fp32 layernorm2d_bwd_fp32.cpp) diff --git a/example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp b/example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0b0a3e72ad7c57e8842f4d68459b3918ffea6ae1 --- /dev/null +++ b/example/53_layernorm2d_bwd/layernorm2d_bwd_fp32.cpp @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DGammaDataType = float; +using DBetaDataType = float; +using DXDataType = float; +using ComputeDataType = float; + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +// Layernorm: + +// Input shape +// dy: [M, N] +// x: [M, N] +// mean: [M, 1] +// inv_std: [M, 1] + +// Output shape +// dx: [M, N] +// dgamma: [1, N] +// dbeta: [1, N] + +// dgamma = reduce_sum(dy * (x - mean) * inv_std, axis=0) +// dbeta = reduce_sum(dy, axis=0) + +// [CAUSION] +// In DeviceNormalizationBwdDataImpl & DeviceNormalizationBwdGammaBetaImpl, M is Invariant +// dimension, K is reduced dimension Hence, M in this example and +// DeviceNormalizationBwdGammaBetaImpl is different +using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl< + DYDataType, + XDataType, + GammaDataType, + MeanInvStdDataType, + ComputeDataType, + DXDataType, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // MThreadClusterSize + 32, // KThreadClusterSize + 1, // MThreadSliceSize + 4, // KThreadSliceSize + true, // IsDYFastestDimReduced + 4, // DYSrcVectorSize + true, // IsXFastestDimReduced + 4, // XSrcVectorSize + true, // IsGammaFastestDimReduced + 4, // GammaSrcVectorSize + false, // IsMeanInvStdFastestDimReduced + 1, // MeanInvStdSrcVectorSize + true, // IsDXFastestDimReduced + 4>; // DXDstVectorSize + +using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl< + DYDataType, + XDataType, + MeanInvStdDataType, + ComputeDataType, + DGammaDataType, + DBetaDataType, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // MThreadClusterSize + 32, // KThreadClusterSize + 4, // MThreadSliceSize + 1, // KThreadSliceSize + false, // IsDYFastestDimReduced + 4, // DYSrcVectorSize + false, // IsXFastestDimReduced + 4, // XSrcVectorSize + true, // IsMeanInvStdFastestDimReduced + 1, // MeanInvStdSrcVectorSize + 4, // DGammaDstVectorSize + 4>; // DBetaDstVectorSize + +int main() +{ + bool time_kernel = false; + + ck::index_t M = 1024; + ck::index_t N = 512; + + Tensor dy({M, N}); + Tensor x({M, N}); + Tensor gamma({N}); + Tensor mean({M}); + Tensor inv_std({M}); + + Tensor dgamma({N}); + Tensor dbeta({N}); + Tensor dx({M, N}); + + dy.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + x.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + mean.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + inv_std.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); + DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize()); + DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize()); + DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); + + dy_dev.ToDevice(dy.mData.data()); + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + mean_dev.ToDevice(mean.mData.data()); + inv_std_dev.ToDevice(inv_std.mData.data()); + + // backward x + auto x_device_instance = XDeviceInstance{}; + + auto x_argument_ptr = x_device_instance.MakeArgumentPointer({M, N}, // lengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {0, 1}, // gammaStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N, 1}, // dxStrides + {1}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__ + << std::endl; + return 1; + }; + + auto x_invoker_ptr = x_device_instance.MakeInvokerPointer(); + x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + // backward gamma & beta + auto gamma_beta_device_instance = GammaBetaDeviceInstance{}; + auto gamma_beta_argument_ptr = + gamma_beta_device_instance.MakeArgumentPointer({M, N}, // inLengths + {N, 1}, // dyStrides + {N, 1}, // xStrides + {1, 0}, // meanStrides + {1, 0}, // invStdStrides + {N}, // outLengths + {1}, // dgammaStrides + {1}, // dbetaStrides + {0}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__ + << std::endl; + return 1; + }; + + auto gamma_beta_invoker_ptr = gamma_beta_device_instance.MakeInvokerPointer(); + gamma_beta_invoker_ptr->Run(gamma_beta_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor host_dgamma({N}); + Tensor host_dbeta({N}); + Tensor host_dx({M, N}); + using ReferenceInstance = + ck::tensor_operation::host::ReferenceLayernormBwd; + + ReferenceInstance ref; + auto ref_argument = + ref.MakeArgument(dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, {M, N}); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + dgamma_dev.FromDevice(dgamma.mData.data()); + dbeta_dev.FromDevice(dbeta.mData.data()); + dx_dev.FromDevice(dx.mData.data()); + + pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); + pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); + pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3); + } + + return (pass ? 0 : 1); +} diff --git a/example/54_groupnorm_bwd/CMakeLists.txt b/example/54_groupnorm_bwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..2cb103499c212bc03455464049ff6005d8f19878 --- /dev/null +++ b/example/54_groupnorm_bwd/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_groupnorm_bwd_fp32 groupnorm_bwd_fp32.cpp) diff --git a/example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp b/example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6cf1b2ff91e6e607f54848f5bbd27fb28a39e8e7 --- /dev/null +++ b/example/54_groupnorm_bwd/groupnorm_bwd_fp32.cpp @@ -0,0 +1,240 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm_bwd.hpp" + +using DYDataType = float; +using XDataType = float; +using GammaDataType = float; +using MeanInvStdDataType = float; +using DGammaDataType = float; +using DBetaDataType = float; +using DXDataType = float; +using ComputeDataType = float; + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +// Grouprnorm +// kernel 1: M , K +// dy: N, H, W, G, C -> N * G, H * W * C +// x: N, H, W, G, C -> N * G, H * W * C +// gamma: 1, 1, 1, G, C -> 1 * G, 1 * 1 * C +// mean: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1 +// rstd: N, 1, 1, G, 1 -> N * G, 1 * 1 * 1 + +// dx: N, H, W, G, C -> N * G, H * W * C + +using XDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdDataImpl< + DYDataType, + XDataType, + GammaDataType, + MeanInvStdDataType, + ComputeDataType, + DXDataType, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // MThreadClusterSize + 32, // KThreadClusterSize + 1, // MThreadSliceSize + 4, // KThreadSliceSize + true, // IsDYFastestDimReduced + 4, // DYSrcVectorSize + true, // IsXFastestDimReduced + 4, // XSrcVectorSize + true, // IsGammaFastestDimReduced + 4, // GammaSrcVectorSize + false, // IsMeanInvStdFastestDimReduced + 1, // MeanInvStdSrcVectorSize + true, // IsDXFastestDimReduced + 4>; // DXDstVectorSize + +// kernel 2: M , K +// dy: N, H, W, G, C -> G * C, N * H * W +// x: N, H, W, G, C -> G * C, N * H * W +// mean: N, 1, 1, G, 1 -> G * 1, N * 1 * 1 +// rstd: N, 1, 1, G, 1 -> G * 1, N * 1 * 1 + +// dgamma: 1, 1, 1, G, C -> G * C +// dbeta: 1, 1, 1, G, C -> G * C + +// reduced axis: 0, 1, 2 + +using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl< + DYDataType, + XDataType, + MeanInvStdDataType, + ComputeDataType, + DGammaDataType, + DBetaDataType, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // ClusterInvariant + 32, // ClusterReduce + 4, // SliceInvariant + 1, // SliceReduce + false, // IsDYFastestDimReduced + 4, // DYSrcVectorSize + false, // IsXFastestDimReduced + 4, // XSrcVectorSize + false, // IsMeanInvStdFastestDimReduced + 1, // MeanInvStdSrcVectorSize + 4, // DGammaDstVectorSize + 4>; // DBetaDstVectorSize + +int main() +{ + bool time_kernel = false; + + ck::index_t N = 16; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t G = 32; + ck::index_t C = 64; + + Tensor dy({N, H, W, G, C}); + Tensor x({N, H, W, G, C}); + Tensor gamma({G, C}); + Tensor mean({N, G}); + Tensor inv_std({N, G}); + + Tensor dgamma({G, C}); + Tensor dbeta({G, C}); + Tensor dx({N, H, W, G, C}); + + dy.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + x.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + mean.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + inv_std.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize()); + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize()); + DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize()); + DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize()); + DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize()); + DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize()); + + dy_dev.ToDevice(dy.mData.data()); + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + mean_dev.ToDevice(mean.mData.data()); + inv_std_dev.ToDevice(inv_std.mData.data()); + + std::vector dyStrides{dy.mDesc.GetStrides().begin(), dy.mDesc.GetStrides().end()}; + std::vector xStrides{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}; + std::vector gammaStrides = {0, 0, 0, C, 1}; + std::vector meanStrides = {G, 0, 0, 1, 0}; + std::vector invStdStrides = {G, 0, 0, 1, 0}; + std::vector dxStrides{dx.mDesc.GetStrides().begin(), dx.mDesc.GetStrides().end()}; + + // backward x + auto x_device_instance = XDeviceInstance{}; + + auto x_argument_ptr = x_device_instance.MakeArgumentPointer({N, H, W, G, C}, // lengths + dyStrides, // dyStrides + xStrides, // xStrides + gammaStrides, // gammaStrides + meanStrides, // meanStrides + invStdStrides, // invStdStrides + dxStrides, // dxStrides + {1, 2, 4}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer()); + + if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__ + << std::endl; + return 1; + }; + + auto x_invoker_ptr = x_device_instance.MakeInvokerPointer(); + x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + // backward gamma & beta + + auto gamma_beta_device_instance = GammaBetaDeviceInstance{}; + auto gamma_beta_argument_ptr = + gamma_beta_device_instance.MakeArgumentPointer({N, H, W, G, C}, // inLengths + dyStrides, // dyStrides + xStrides, // xStrides + meanStrides, // meanStrides + invStdStrides, // invStdStrides + {G, C}, // outLengths + {C, 1}, // dgammaStrides + {C, 1}, // dbetaStrides + {0, 1, 2}, // reduceDims + dy_dev.GetDeviceBuffer(), + x_dev.GetDeviceBuffer(), + mean_dev.GetDeviceBuffer(), + inv_std_dev.GetDeviceBuffer(), + dgamma_dev.GetDeviceBuffer(), + dbeta_dev.GetDeviceBuffer()); + + if(!gamma_beta_device_instance.IsSupportedArgument(gamma_beta_argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported." << __FILE__ << ":" << __LINE__ + << std::endl; + return 1; + }; + + auto gamma_beta_invoker_ptr = gamma_beta_device_instance.MakeInvokerPointer(); + gamma_beta_invoker_ptr->Run(gamma_beta_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor host_dgamma({G, C}); + Tensor host_dbeta({G, C}); + Tensor host_dx({N, H, W, G, C}); + using ReferenceInstance = + ck::tensor_operation::host::ReferenceGroupnormBwd; + + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument( + dy, x, gamma, mean, inv_std, host_dgamma, host_dbeta, host_dx, {N, H, W, G, C}); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + dgamma_dev.FromDevice(dgamma.mData.data()); + dbeta_dev.FromDevice(dbeta.mData.data()); + dx_dev.FromDevice(dx.mData.data()); + + pass &= ck::utils::check_err(dgamma, host_dgamma, "Error: Incorrect dgamma", 1e-3, 1e-3); + pass &= ck::utils::check_err(dbeta, host_dbeta, "Error: Incorrect dbeta", 1e-3, 1e-3); + pass &= ck::utils::check_err(dx, host_dx, "Error: Incorrect dx", 1e-3, 1e-3); + } + + return (pass ? 0 : 1); +} diff --git a/example/60_gemm_multi_ABD/CMakeLists.txt b/example/60_gemm_multi_ABD/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..57bc0b33ef0aaf22cd9613e9612bf6a2ae593c51 --- /dev/null +++ b/example/60_gemm_multi_ABD/CMakeLists.txt @@ -0,0 +1,8 @@ +list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list2 AND target EQUAL 0) + add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13ba6398140793627420271257ae9ca061c9302b --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp @@ -0,0 +1,363 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +struct AddScale +{ + static constexpr auto I0 = ck::Number<0>{}; + static constexpr auto I1 = ck::Number<1>{}; + static constexpr auto I2 = ck::Number<2>{}; + static constexpr auto I3 = ck::Number<3>{}; + + __host__ __device__ constexpr void + operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const + { + const auto a0_v_t = ck::vector_type{a0}; + const auto a1_v_t = ck::vector_type{a1}; + + auto r_v_t = ck::vector_type{}; + + r_v_t.AsType()(I0) = + scale * (a0_v_t.AsType()[I0] + a1_v_t.AsType()[I0]); + r_v_t.AsType()(I1) = + scale * (a0_v_t.AsType()[I1] + a1_v_t.AsType()[I1]); + r_v_t.AsType()(I2) = + scale * (a0_v_t.AsType()[I2] + a1_v_t.AsType()[I2]); + r_v_t.AsType()(I3) = + scale * (a0_v_t.AsType()[I3] + a1_v_t.AsType()[I3]); + + a = r_v_t.AsType()[I0]; + } + + __host__ __device__ constexpr void + operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const + { + a = scale * (a0 + a1); + } + + // this attribute controls the copy_function applying element_wise_op with + // pack4_data + constexpr const static bool is_pack4_invocable = true; + + float scale = 1.0; +}; + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +using AElementOp = AddScale; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ELayout, + ck::Tuple, + ck::Tuple, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + float alpha = 1.0f; + float beta = 1.0f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{0.2}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}, + std::array{b_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA, StrideA}, + std::array{StrideB}, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor a_m_k({M, K}); + + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) + { + a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/61_contraction_multi_ABD/CMakeLists.txt b/example/61_contraction_multi_ABD/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..42500b64e69fb8345002498b281e7fe5f71fc575 --- /dev/null +++ b/example/61_contraction_multi_ABD/CMakeLists.txt @@ -0,0 +1,8 @@ +list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list2 AND target EQUAL 0) + add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4317484336e1d023f1534fb11c482afebab7b5e1 --- /dev/null +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using A0DataType = F16; +using A1DataType = F32; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; +using ComputeDataType = F16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +struct Multiply +{ + __host__ __device__ constexpr void + operator()(ck::half_t& a, const ck::half_t& a0, const float& a1) const + { + a = ck::type_convert(ck::type_convert(a0) * a1); + } +}; + +using AElementOp = Multiply; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultipleABD_Xdl_CShuffle< + NumDimM, + NumDimN, + NumDimK, + ck::Tuple, + ck::Tuple, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + float alpha = 1.0f; + float beta = 1.0f; + + // A0[M0, M1, K0, K1] + std::vector a0_ms_ks_lengths{30, 128, 32, 64}; + std::vector a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1}; + // A1[M1, K1] -> A1[M0, M1, K0, K1] + std::vector a1_ms_ks_lengths{30, 128, 32, 64}; + std::vector a1_ms_ks_strides{0, 64, 0, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; + std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_ms_ks.mData.data()); + a1_device_buf.ToDevice(a1_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + d_device_buf.ToDevice(d_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + std::array{a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}, + std::array{b_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + std::array, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths}, + std::array, 2>{a0_ms_ks_strides, a1_ms_ks_strides}, + std::array, 1>{b_ns_ks_lengths}, + std::array, 1>{b_ns_ks_strides}, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_contraction with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + { + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a0_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + + if(do_verification) + { + + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + + for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < a_ms_ks.mDesc.GetLengths()[1]; ++m1) + { + for(size_t k0 = 0; k0 < a_ms_ks.mDesc.GetLengths()[2]; ++k0) + { + for(size_t k1 = 0; k1 < a_ms_ks.mDesc.GetLengths()[3]; ++k1) + { + a_element_op(a_ms_ks(m0, m1, k0, k1), + a0_ms_ks(m0, m1, k0, k1), + a1_ms_ks(m0, m1, k0, k1)); + } + } + } + } + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor empty_tensor(std::vector{}, std::vector{}); + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, PassThrough{}, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1), + d_ms_ns(m0, m1, n0, n1)); + } + } + } + } + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/62_convnd_activ/CMakeLists.txt b/example/62_convnd_activ/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..6eaddd3ff7cc27a9f45c55102fa01a09ae67aeda --- /dev/null +++ b/example/62_convnd_activ/CMakeLists.txt @@ -0,0 +1,17 @@ +add_subdirectory(binary) +add_subdirectory(multi_AB) +add_subdirectory(unary) + +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_xdl) + # ScaleAdd ScaleAdd Relu + add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) + add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) + add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp) + add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16) + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/binary/CMakeLists.txt b/example/62_convnd_activ/binary/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7c07b6bca6244eeef9662511e8ebc3d36b7d5b08 --- /dev/null +++ b/example/62_convnd_activ/binary/CMakeLists.txt @@ -0,0 +1,13 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_binary_xdl) + # Bilinear residual + add_example_executable(example_convnd_fwd_xdl_bilinear_residual_fp16 convnd_fwd_xdl_bilinear_residual_fp16.cpp) + add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_fwd_xdl_bilinear_residual_fp16) + add_example_executable(example_convnd_bwd_data_xdl_bilinear_residual_fp16 convnd_bwd_data_xdl_bilinear_residual_fp16.cpp) + add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_bwd_data_xdl_bilinear_residual_fp16) + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/binary/convnd_bwd_data_xdl_bilinear_residual_fp16.cpp b/example/62_convnd_activ/binary/convnd_bwd_data_xdl_bilinear_residual_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f5bddf2302ea8134d1c5dc361910171d5e44263d --- /dev/null +++ b/example/62_convnd_activ/binary/convnd_bwd_data_xdl_bilinear_residual_fp16.cpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +template +using DeviceGroupedConvNDBwdDataInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< + NDimSpatial, + OutLayout, + WeiLayout, + ck::Tuple, + InLayout, + OutDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + InDataType, + OutElementOp, + WeiElementOp, + InElementOp, + ConvSpec, // ConvForwardSpecialization + true, + true, + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 2, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_BK1 + 0, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDBwdDataInstance; + +namespace { +// Use custom implementation to pass two more tensors for post op +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + constexpr ck::index_t NumDs = 1; + Tensor out(out_g_n_k_wos_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor in_host(in_g_n_c_wis_desc); + + std::cout << "out: " << out.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "in: " << in_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + in_host.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + in_host.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + // Initialize based on out_host + Tensor in_device(in_host); + + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize()); + + out_device_buf.ToDevice(out.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + in_device_buf.ToDevice(in_device.mData.data()); + + std::array a_g_n_k_wos_lengths{}; + std::array a_g_n_k_wos_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_c_wis_lengths{}; + std::array e_g_n_c_wis_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(out_g_n_k_wos_desc.GetLengths(), a_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), a_g_n_k_wos_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(in_g_n_c_wis_desc.GetLengths(), e_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), e_g_n_c_wis_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // Use output as D + const std::array ds = {in_device_buf.GetDeviceBuffer()}; + + auto conv = DeviceConvNDInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument( + out_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + ds, + in_device_buf.GetDeviceBuffer(), + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, NumDs>{e_g_n_c_wis_lengths}, + std::array, NumDs>{e_g_n_c_wis_strides}, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + out_element_op, + wei_element_op, + in_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + conv_param.GetFlops() + 3 * conv_param.GetInputByte() / sizeof(InDataType); + std::size_t num_btype = conv_param.GetByte() + + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + std::array, NumDs> d_tensors = {in_host}; + auto ref_conv = + ck::tensor_operation::host::ReferenceConvBwdData(); + + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_host, + wei, + out, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + ref_invoker.Run(ref_argument); + + in_device_buf.FromDevice(in_device.mData.data()); + + return ck::utils::check_err(in_device.mData, in_host.mData); + } + + return true; +} + +} // namespace + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/binary/convnd_fwd_xdl_bilinear_residual_fp16.cpp b/example/62_convnd_activ/binary/convnd_fwd_xdl_bilinear_residual_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ae1ebcb2cd64d0da533978995863e2f9ce80e6a5 --- /dev/null +++ b/example/62_convnd_activ/binary/convnd_fwd_xdl_bilinear_residual_fp16.cpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + +using OutElementOp = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; + +namespace { +// Use custom implementation to pass two more tensors for post op +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + constexpr ck::index_t NumDs = 1; + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + out_host.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + out_host.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + } + + // Initialize based on out_host + Tensor out_device(out_host); + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + out_device_buf.ToDevice(out_device.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // Use output as D + const std::array ds = {out_device_buf.GetDeviceBuffer()}; + + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument( + in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + ds, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, NumDs>{e_g_n_k_wos_lengths}, + std::array, NumDs>{e_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + conv_param.GetFlops() + 3 * conv_param.GetOutputByte() / sizeof(OutDataType); + std::size_t num_btype = conv_param.GetByte() + + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + std::array, NumDs> d_tensors = {out_host}; + auto ref_conv = + ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} + +} // namespace + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d101fd59bd9281a81f46c7b28868f3ad07626de0 --- /dev/null +++ b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +using BiasLayout = ck::tensor_layout::convolution::G_K; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + +using OutElementOp = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; + +namespace { +// Use custom implementation to pass two more tensors for post op +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + constexpr ck::index_t NumDs = 2; + const ck::index_t G = out_g_n_k_wos_desc.GetLengths()[0]; + const ck::index_t K = out_g_n_k_wos_desc.GetLengths()[2]; + + // Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW) + std::array bias_g_k_lengths; + std::array bias_g_k_strides; + // Fill other lenghts than G,K with 1 and strides with 0 + bias_g_k_lengths.fill(1); + bias_g_k_strides.fill(0); + bias_g_k_lengths[0] = G; + bias_g_k_lengths[2] = K; + bias_g_k_strides[0] = K; // stride to G + bias_g_k_strides[2] = 1; // stride to K + const auto broadcasted_bias_desc = HostTensorDescriptor(bias_g_k_lengths, bias_g_k_strides); + + // y = relu ( alpha1 * conv(x) + alpha2 * z + bias ) + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + std::array, NumDs> d_tensors = {Tensor(out_g_n_k_wos_desc), + Tensor(broadcasted_bias_desc)}; + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + std::cout << "z_tensor: " << d_tensors[0].mDesc << std::endl; + std::cout << "bias_tensor: " << d_tensors[1].mDesc << std::endl; + + // Make sure that we allocated only G * K values for bias + assert(static_cast(d_tensors[1].mData.size()) == G * K); + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d_tensors[0].GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d_tensors[1].GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + d_tensors[0].GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + d_tensors[1].GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem z_buf(sizeof(OutDataType) * d_tensors[0].mDesc.GetElementSpaceSize()); + DeviceMem bias_buf(sizeof(OutDataType) * d_tensors[1].mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + z_buf.ToDevice(d_tensors[0].mData.data()); + bias_buf.ToDevice(d_tensors[1].mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + const std::array ds = {z_buf.GetDeviceBuffer(), bias_buf.GetDeviceBuffer()}; + + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + ds, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, NumDs>{ + e_g_n_k_wos_lengths, bias_g_k_lengths}, + std::array, NumDs>{ + e_g_n_k_wos_strides, bias_g_k_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops() + G * K + + conv_param.GetOutputByte() / sizeof(OutDataType); + std::size_t num_btype = conv_param.GetByte() + + G * K * sizeof(OutDataType) + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = + ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} + +} // namespace + +#include "run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f784655cc57da88f5e67eeb5dbcdf440b8b7d35f --- /dev/null +++ b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + +using OutElementOp = ck::tensor_operation::element_wise::ScaleAddScaleAddRelu; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; + +namespace { +// Use custom implementation to pass two more tensors for post op +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + constexpr ck::index_t NumDs = 2; + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + std::array, NumDs> d_tensors = {Tensor(out_g_n_k_wos_desc), + Tensor(out_g_n_k_wos_desc)}; + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d_tensors[0].GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d_tensors[1].GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + d_tensors[0].GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + d_tensors[1].GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem d0_buf(sizeof(OutDataType) * d_tensors[0].mDesc.GetElementSpaceSize()); + DeviceMem d1_buf(sizeof(OutDataType) * d_tensors[1].mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + d0_buf.ToDevice(d_tensors[0].mData.data()); + d1_buf.ToDevice(d_tensors[1].mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + const std::array ds = {d0_buf.GetDeviceBuffer(), d1_buf.GetDeviceBuffer()}; + + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + ds, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, NumDs>{ + e_g_n_k_wos_lengths, e_g_n_k_wos_lengths}, + std::array, NumDs>{ + e_g_n_k_wos_strides, e_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + conv_param.GetFlops() + 2 * conv_param.GetOutputByte() / sizeof(OutDataType); + std::size_t num_btype = conv_param.GetByte() + + 2 * conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = + ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} + +} // namespace + +#include "run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/multi_AB/CMakeLists.txt b/example/62_convnd_activ/multi_AB/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..c89c82d384c68a9676025421d04d542e28725710 --- /dev/null +++ b/example/62_convnd_activ/multi_AB/CMakeLists.txt @@ -0,0 +1,17 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_multi_ab_xdl) + # ScaleAdd on A and B + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp16 conv_fwd_xdl_scaleadd_ab_fp16.cpp) + add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_fp16) + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp32 conv_fwd_xdl_scaleadd_ab_fp32.cpp) + add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_fp32) + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_bf16 conv_fwd_xdl_scaleadd_ab_bf16.cpp) + add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_bf16) + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_int8 conv_fwd_xdl_scaleadd_ab_int8.cpp) + add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_int8) + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b7ceee03b829a943a657a1c2090f1a9fc6bd3c44 --- /dev/null +++ b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = ck::bhalf_t; +using AccDataType = float; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple; +using BDataTypes = ck::Tuple; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance; + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..08d8a89669cc85b79882d5f8a5339eba6f4c76ad --- /dev/null +++ b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = ck::half_t; +using AccDataType = float; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple; +using BDataTypes = ck::Tuple; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance; + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bef9980b3e3b1604efec00d1dcef50622b9dee73 --- /dev/null +++ b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = float; +using AccDataType = float; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple; +using BDataTypes = ck::Tuple; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance; + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2b132b91213e9eb0237a21b63ba9e80f49e3c400 --- /dev/null +++ b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_multi_ab_common.hpp" + +using DataType = int8_t; +using AccDataType = int32_t; +using InDataType = DataType; +using WeiDataType = DataType; +using OutDataType = DataType; +using ADataTypes = ck::Tuple; +using BDataTypes = ck::Tuple; + +using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance; + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp b/example/62_convnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2626843ed4b933d66c27213c324e90fdd75a10f6 --- /dev/null +++ b/example/62_convnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDMultiABFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataTypes, + WeiDataTypes, + AccDataType, + DataType, + ck::Tuple<>, + DataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +namespace { +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + constexpr ck::index_t NumAs = 2; + constexpr ck::index_t NumBs = 2; + Tensor in(in_g_n_c_wis_desc); + Tensor in_bias(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor wei_bias(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + in_bias.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei_bias.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + in_bias.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + wei_bias.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem in_bias_device_buf(sizeof(InDataType) * in_bias.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem wei_bias_device_buf(sizeof(WeiDataType) * wei_bias.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + in_bias_device_buf.ToDevice(in_bias.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + wei_bias_device_buf.ToDevice(wei_bias.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + std::array as{in_device_buf.GetDeviceBuffer(), + in_bias_device_buf.GetDeviceBuffer()}; + std::array bs{wei_device_buf.GetDeviceBuffer(), + wei_bias_device_buf.GetDeviceBuffer()}; + std::array ds{}; + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(as, + bs, + ds, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {}, + {}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops() + + 2 * conv_param.GetOutputByte() / sizeof(InDataType) + + 2 * conv_param.GetOutputByte() / sizeof(WeiDataType); + std::size_t num_btype = conv_param.GetByte() + + conv_param.GetInputByte() + + conv_param.GetWeightByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + const std::array, NumAs - 1> elementwise_a_tensors = {in_bias}; + const std::array, NumBs - 1> elementwise_b_tensors = {wei_bias}; + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + elementwise_a_tensors, + elementwise_b_tensors); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} + +} // namespace diff --git a/example/62_convnd_activ/run_convnd_activ_example.inc b/example/62_convnd_activ/run_convnd_activ_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..5a402e41cdeb1ed8a7655b2a0580391ed5cc9226 --- /dev/null +++ b/example/62_convnd_activ/run_convnd_activ_example.inc @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +bool run_convnd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + // Use floats for SoftRelu by default to avoid overflow after e^x. + int init_method = + std::is_same_v ? 2 : 1; + bool time_kernel = false; + + // Following shapes are selected to avoid overflow. Expect inf in case of + // size increase for some elementwise ops. + ck::utils::conv::ConvParam conv_param{ + 3, 2, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + const auto run = [&]() { + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + }; + + if(conv_param.num_dim_spatial_ == 3) + { + return run(); + } + + return false; +} diff --git a/example/62_convnd_activ/unary/CMakeLists.txt b/example/62_convnd_activ/unary/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..94ffb3661ceb2787b292fa6549c18391d2389c52 --- /dev/null +++ b/example/62_convnd_activ/unary/CMakeLists.txt @@ -0,0 +1,35 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_unary_xdl) + # Sigmoid + add_example_executable(example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_sigmoid_fp16) + # Tanh + add_example_executable(example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_tanh_fp16) + # Relu + add_example_executable(example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_relu_fp16) + # SoftRelu + add_example_executable(example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_softrelu_fp16) + # Abs + add_example_executable(example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_abs_fp16) + # Pow + add_example_executable(example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_pow_fp16) + # Clipped Relu + add_example_executable(example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_clippedrelu_fp16) + # Leaky Relu + add_example_executable(example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_leakyrelu_fp16) + # Elu + add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_elu_fp16) + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/unary/convnd_fwd_activ_unary_common.hpp b/example/62_convnd_activ/unary/convnd_fwd_activ_unary_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4669465bf440077577202a54807e32763bc04980 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_activ_unary_common.hpp @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_abs_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_abs_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e621c3b15e3178969ef2fc3fb44203c1efb6032f --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_abs_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::UnaryAbs; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_clippedrelu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_clippedrelu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bb26fb292cd8282951864b62c46846a9a0c7e545 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_clippedrelu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::ClippedRelu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_elu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_elu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1aa4d5d4fa224b5857c19491317bce14d24d6dea --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_elu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Elu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..659c36ec8ddee15844b4fd2bf987c55eb725e2ca --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::LeakyRelu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5efa0f8f9c76ac6062a699166fe7c2d8c104e192 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Power; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..84b7c598eecac8e4ddf3964172921c2822f07395 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Relu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..53e06d387f93127448df44dab9e4c8090805e434 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Sigmoid; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a2d76da4e214f44fc20da929a0ceea52cb6ceba3 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::SoftRelu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d60c005e0937cc9120cfddd67c52cbfa39ab1898 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::TanH; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/63_layernorm4d_fwd/CMakeLists.txt b/example/63_layernorm4d_fwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..3f8c679ab82c8edaa3a96f2dcf5afded77fabe8d --- /dev/null +++ b/example/63_layernorm4d_fwd/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_layernorm4d_fwd_fp16 layernorm4d_fwd_fp16.cpp) +add_example_executable(example_layernorm4d_fwd_splitk_fp16 layernorm4d_fwd_splitk_fp16.cpp) diff --git a/example/63_layernorm4d_fwd/common.hpp b/example/63_layernorm4d_fwd/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6c7b99f89f9930cdcb84bda674f96ab3a1237241 --- /dev/null +++ b/example/63_layernorm4d_fwd/common.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" diff --git a/example/63_layernorm4d_fwd/layernorm4d_fwd_fp16.cpp b/example/63_layernorm4d_fwd/layernorm4d_fwd_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..659cc3554d2a7f2c28f3d55bbb8a6ac65fe1c173 --- /dev/null +++ b/example/63_layernorm4d_fwd/layernorm4d_fwd_fp16.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD + +constexpr int Rank = 4; +constexpr int NumReduceDim = 3; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationFwdImpl; // SaveMeanInvStdScalarPerVector +#include "run_layernorm4d_fwd_example.inc" + +int main() { return run_layernorm4d_fwd_example(); } diff --git a/example/63_layernorm4d_fwd/layernorm4d_fwd_splitk_fp16.cpp b/example/63_layernorm4d_fwd/layernorm4d_fwd_splitk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..415635a0e32d2c0777cbe87d0490863fd4a80e89 --- /dev/null +++ b/example/63_layernorm4d_fwd/layernorm4d_fwd_splitk_fp16.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD + +constexpr int Rank = 4; +constexpr int NumReduceDim = 3; + +using DeviceInstance = ck::tensor_operation::device::DeviceNormalizationFwdSplitKImpl< + XDataType, + GammaDataType, + BetaDataType, + ComputeDataType, + YDataType, + SaveMeanInvStdDataType, + PassThrough, + Rank, + NumReduceDim, + 256, // BlockSize + 8, // ClusterM + 32, // ClusterK + 1, // SliceM + 8, // SliceK + 1, // XYVectorDim (0=M, 1=K) + 8, // XScalarPerVector + 1, // GammaVecDim (0=M, 1=K) + 8, // GammaScalarPerVector + 1, // BetaVecDim (0=M, 1=K) + 8, // BetaScalarPerVector + 8, // YScalarPerVector + 1>; // SaveMeanInvStdScalarPerVector + +#include "run_layernorm4d_fwd_example.inc" + +int main() { return run_layernorm4d_fwd_example(); } diff --git a/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc b/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..1a0b558e2ca02ec61be7c3988b57e3dd93fb3040 --- /dev/null +++ b/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +int run_layernorm4d_fwd_example() +{ + bool time_kernel = false; + + ck::index_t N = 256; + ck::index_t H = 16; + ck::index_t W = 16; + ck::index_t C = 8; + + Tensor x({N, H, W, C}); + Tensor gamma({H, W, C}); + Tensor beta({H, W, C}); + Tensor y({N, H, W, C}); + Tensor save_mean({N}); + Tensor save_inv_std({N}); + + x.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); +#ifdef SAVE_MEAN_INV_STD + DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize()); + DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) * + save_inv_std.mDesc.GetElementSpaceSize()); +#endif + + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + auto device_instance = DeviceInstance{}; + auto argument_ptr = device_instance.MakeArgumentPointer( + {N, H, W, C}, + std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, + {0, W * C, C, 1}, + {0, W * C, C, 1}, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + {1, 2, 3}, + 1e-4, + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.GetDeviceBuffer(), + save_inv_std_dev.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + PassThrough{}); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported" << std::endl; + return 1; + }; + + size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor host_y({N, H, W, C}); + Tensor host_save_mean({N}); + Tensor host_save_inv_std({N}); + + using ReferenceInstance = + ck::tensor_operation::host::ReferenceLayernorm; + + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument(x, + gamma, + beta, + host_y, + host_save_mean, + host_save_inv_std, + PassThrough{}, + {N, H, W, C}, + {1, 2, 3}, + 1e-4); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + y_dev.FromDevice(y.mData.data()); + pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results (y)", 1e-3, 1e-3); +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.FromDevice(save_mean.mData.data()); + save_inv_std_dev.FromDevice(save_inv_std.mData.data()); + pass &= ck::utils::check_err( + save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3); + pass &= ck::utils::check_err( + save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3); +#endif + } + + return (pass ? 0 : 1); +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 1fdd2f6d14826776b734b76815e192b4654e5b09..c19ba93b69ff635daea2b8d8008971b1b03a0402 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -7,20 +7,112 @@ add_custom_target(examples) function(add_example_executable EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") - add_executable(${EXAMPLE_NAME} ${FILE_NAME}) - target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) - add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) - add_dependencies(examples ${EXAMPLE_NAME}) - add_dependencies(check ${EXAMPLE_NAME}) - rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 1) + if(DEFINED DTYPES) + foreach(source IN LISTS FILE_NAME) + set(test 0) + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message("removing example source file ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + endif() + foreach(source IN LISTS FILE_NAME) + if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + message("removing dl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #only continue if there are some source files left on the list + if(FILE_NAME) + add_executable(${EXAMPLE_NAME} ${FILE_NAME}) + target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) + add_dependencies(examples ${EXAMPLE_NAME}) + add_dependencies(check ${EXAMPLE_NAME}) + rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 0) + endif() + #message("add_example returns ${result}") + set(result ${result} PARENT_SCOPE) endfunction(add_example_executable EXAMPLE_NAME) +function(add_example_dependencies EXAMPLE_NAME FILE_NAME) + if(result EQUAL 0) + add_dependencies(${EXAMPLE_NAME} ${FILE_NAME}) + endif() +endfunction(add_example_dependencies EXAMPLE_NAME) + function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") - add_executable(${EXAMPLE_NAME} ${FILE_NAME}) - target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) - add_dependencies(examples ${EXAMPLE_NAME}) - rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 1) + if(DEFINED DTYPES) + foreach(source IN LISTS FILE_NAME) + set(test 0) + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message("removing example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + endif() + foreach(source IN LISTS FILE_NAME) + if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + message("removing dl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #only continue if there are some source files left on the list + if(FILE_NAME) + add_executable(${EXAMPLE_NAME} ${FILE_NAME}) + target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + add_dependencies(examples ${EXAMPLE_NAME}) + rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 0) + endif() + #message("add_example returns ${result}") + set(result ${result} PARENT_SCOPE) endfunction(add_example_executable_no_testing EXAMPLE_NAME) # add all example subdir diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 94ba97c6e1404918fe805414e41c339a2010e03d..16afcf320dff1a5759238a63879e8bde4a49a490 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -45,16 +45,30 @@ #define CK_USE_WAVES_PER_EU 0 #endif +// define general macros for various architectures +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#define __gfx94__ +#endif +#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) +#define __gfx101__ +#endif +#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) +#define __gfx103__ +#endif +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) +#define __gfx11__ +#endif + // buffer resource #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_BUFFER_RESOURCE_3RD_DWORD -1 #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) // for GPU code + defined(__gfx90a__) || defined(__gfx94__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 -#elif defined(__gfx1030__) // for GPU code +#elif defined(__gfx103__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 -#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code +#elif defined(__gfx11__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 #endif @@ -62,33 +76,36 @@ #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code #define CK_USE_AMD_V_MAC_F32 -#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // for GPU code +#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \ + defined(__gfx94__) // for GPU code #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 +#elif defined(__gfx11__) +#define CK_USE_AMD_V_FMAC_F32 +#define CK_USE_AMD_V_DOT2_F32_F16 +#define CK_USE_AMD_V_DOT4_I32_I8_GFX11 #endif // MFMA instruction #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_MFMA -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) // for GPU code +#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code #define CK_USE_AMD_MFMA #endif -#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(defined(__gfx90a__) || defined(__gfx94__)) #define CK_USE_AMD_MFMA_BF16_1K_OP #endif -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) #define CK_USE_AMD_MFMA_GFX940 #endif // WMMA instruction #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_WMMA -#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code +#elif defined(__gfx11__) // for GPU code #define CK_USE_AMD_WMMA #endif @@ -104,15 +121,13 @@ // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__) // for GPU code +#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 #endif -#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ - defined(__gfx942__)) // for GPU code +#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 #else #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 @@ -131,6 +146,12 @@ // inner product using V_DOT with DPP8 modifiers #define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1 +// LDS direct loads using inline assembly +#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1 + +// set stochastic rounding as default for f8 conversions +#define CK_USE_SR_F8_CONVERSION 1 + // block synchronization only s_wait lgkmcnt(0), not vmcnt(0) #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 @@ -209,7 +230,7 @@ // denorm test fix, required to work around dissue #ifndef CK_WORKAROUND_DENORM_FIX #define CK_WORKAROUND_DENORM_FIX 0 -#elif +#else // enable only on MI200 #define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) #endif // CK_WORKAROUND_DENORM_FIX diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 13dc5da5d179443b603cce5a2aa6274061856db3..1748344756771a2d57dfd044d79f8c72f26f8fb4 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -43,6 +43,9 @@ #ifndef CK_ENABLE_FP8 #define CK_ENABLE_FP8 "ON" #endif +#ifndef CK_ENABLE_BF8 +#define CK_ENABLE_BF8 "ON" +#endif #ifndef CK_ENABLE_FP16 #define CK_ENABLE_FP16 "ON" #endif @@ -66,6 +69,10 @@ #cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@ #endif +#ifndef CK_ENABLE_BF8 +#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@ +#endif + #ifndef CK_ENABLE_FP16 #cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@ #endif diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 24b3a50d29977907338a06f5e7f33873f690635b..b470527a070a3c51446c14f364eafbd97acca267 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -27,7 +27,7 @@ inline std::string get_device_name() } const std::string raw_name(props.gcnArchName); - // https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 + // https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 static std::map device_name_map = { {"Ellesmere", "gfx803"}, {"Baffin", "gfx803"}, @@ -59,5 +59,31 @@ inline bool is_xdl_supported() ck::get_device_name() == "gfx942"; } +inline bool is_lds_direct_load_supported() +{ + // Check if direct loads from global memory to LDS are supported. + return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" || + ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; +} + +inline bool is_navi1_supported() +{ + return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" || + ck::get_device_name() == "gfx1012"; +} + +inline bool is_navi2_supported() +{ + return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" || + ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" || + ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036"; +} + +inline bool is_navi3_supported() +{ + return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; +} + } // namespace ck #endif diff --git a/include/ck/host_utility/hip_check_error.hpp b/include/ck/host_utility/hip_check_error.hpp index af7bebd9d6afbf59f1852156b1708ed843d9a9f6..c0894f1d70e439e51e190905f79d3dcd31637c58 100644 --- a/include/ck/host_utility/hip_check_error.hpp +++ b/include/ck/host_utility/hip_check_error.hpp @@ -3,15 +3,32 @@ #pragma once +#include #include +// To be removed, which really does not tell the location of failed HIP functional call inline void hip_check_error(hipError_t x) { if(x != hipSuccess) { std::ostringstream ss; - ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__ - << "in function: " << __func__; + ss << "HIP runtime error: " << hipGetErrorString(x) << ". " + << "hip_check_error.hpp" + << ": " << __LINE__ << "in function: " << __func__; throw std::runtime_error(ss.str()); } } + +#define HIP_CHECK_ERROR(retval_or_funcall) \ + do \ + { \ + hipError_t _tmpVal = retval_or_funcall; \ + if(_tmpVal != hipSuccess) \ + { \ + std::ostringstream ostr; \ + ostr << "HIP Function Failed (" \ + << "hip_check_error.hpp" \ + << "," << __LINE__ << ") " << hipGetErrorString(_tmpVal); \ + throw std::runtime_error(ostr.str()); \ + } \ + } while(0) diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index c3336be6c84c6c0d591665b5310365d39fdd9f7f..ba21600efae4491f22a2970e21c3577068804eb8 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -30,12 +30,16 @@ float launch_and_time_kernel(const StreamConfig& stream_config, block_dim.y, block_dim.z); - printf("Warm up 1 time\n"); + printf("Warm up %d times\n", stream_config.cold_niters_); #endif // warm up - kernel<<>>(args...); + for(int i = 0; i < stream_config.cold_niters_; ++i) + { + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + } - const int nrepeat = 10; + const int nrepeat = stream_config.nrepeat_; #if DEBUG_LOG printf("Start running %d times...\n", nrepeat); #endif @@ -50,6 +54,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, for(int i = 0; i < nrepeat; ++i) { kernel<<>>(args...); + hip_check_error(hipGetLastError()); } hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); @@ -64,11 +69,13 @@ float launch_and_time_kernel(const StreamConfig& stream_config, else { kernel<<>>(args...); + hip_check_error(hipGetLastError()); return 0; } #else kernel<<>>(args...); + hip_check_error(hipGetLastError()); return 0; #endif @@ -96,13 +103,17 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, block_dim.y, block_dim.z); - printf("Warm up 1 time\n"); + printf("Warm up %d times\n", stream_config.cold_niters_); #endif // warm up preprocess(); - kernel<<>>(args...); + for(int i = 0; i < stream_config.cold_niters_; ++i) + { + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + } - const int nrepeat = 10; + const int nrepeat = stream_config.nrepeat_; #if DEBUG_LOG printf("Start running %d times...\n", nrepeat); #endif @@ -118,6 +129,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, { preprocess(); kernel<<>>(args...); + hip_check_error(hipGetLastError()); } hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); @@ -133,11 +145,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, { preprocess(); kernel<<>>(args...); + hip_check_error(hipGetLastError()); return 0; } #else kernel<<>>(args...); + hip_check_error(hipGetLastError()); return 0; #endif diff --git a/include/ck/stream_config.hpp b/include/ck/stream_config.hpp index 505a602b240428bcd1f0f81018fef0c4716c80b7..a5b14073056feb4670e7f8aa044824487491108b 100644 --- a/include/ck/stream_config.hpp +++ b/include/ck/stream_config.hpp @@ -11,4 +11,6 @@ struct StreamConfig hipStream_t stream_id_ = nullptr; bool time_kernel_ = false; int log_level_ = 0; + int cold_niters_ = 5; + int nrepeat_ = 50; }; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp deleted file mode 100644 index e527509f57af60ef9065ccdbf8c2112d7ed336af..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp +++ /dev/null @@ -1,370 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/amd_gemm_dpp.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_adaptor.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp" - -namespace ck { - -/** - * DPP8 version of blockwise GEMM algorithm. It uses DPP8 instruction modifier to limit - * the data loaded from LDS to registers. - * - * The algorithm groups threads into groups of size `dpp8::lane_group_size` and splits the matrix C - * between them in such a way that threads from the same group need the same chunk of either - * matrix A (or B, respectively). Without the usage of DPP8, each thread would need to load the - * whole chunk from LDS to its own register space. - * Usage of DPP8 modifiers allow each thread to load less data, exactly `1 / dpp8::lane_group_size` - * of the chunk, and then share that data with other threads from the same lane group. - * - * Assumptions coming from the usage of DPP8: - * 1. `BM10BN10ThreadClusterBM10Xs[1] == dpp8::lane_group_size` or - * `BM10BN10ThreadClusterBN10Xs[1] == dpp8::lane_group_size` - - * - it makes consecutive `dpp8::lane_group_size` threads use the same chunk of either - * matrix A or B; - * - based on these values we determine which matrix to share. - * 2. `BM1PerThreadBM11 % dpp8::lane_group_size == 0` (if sharing A) or - * `BN1PerThreadBN11 % dpp8::lane_group_size == 0` (if sharing B) - - * - we have to make sure that the data to split is divisible by the number of - * threads in the group. - * - * General algorithm: - * C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1] - * A and B are visible to the whole block, C is distributed among each thread - * Assume: - * 1. A: - * 1. ABlockDesc_BK0_BM_BK1 is known at compile-time - * 2. ABlockBuffer is DynamicBuffer - * 2. B: - * 1. BBlockDesc_BK0_BN_BK1 is known at compile-time - * 2. BBlockBuffer is DynamicBuffer - * 3. C: - * 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time - * 2. CThreadBuffer is StaticBuffer - * 4. BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2 - */ -template - typename BM10BN10ThreadClusterBN10Xs, // Sequence - index_t AThreadCopyScalarPerVector_BM11, - index_t BThreadCopyScalarPerVector_BN11, - typename enable_if::type = false> -struct BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0 -{ - using AIndex = MultiIndex<4>; - using BIndex = MultiIndex<4>; - using CIndex = MultiIndex<4>; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0); - static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2); - static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1); - static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1); - - static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0]; - static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0]; - - static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1]; - static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1]; - - static constexpr index_t BM11 = BM1PerThreadBM11; - static constexpr index_t BN11 = BN1PerThreadBN11; - - static constexpr index_t BM1 = BM100 * BM101 * BM11; - static constexpr index_t BN1 = BN100 * BN101 * BN11; - - static constexpr index_t BM0 = BM / BM1; - static constexpr index_t BN0 = BN / BN1; - - // We assume that either `BM101` or `BN101` is equal to `dpp8::lane_group_size`. It makes all - // threads in a lane group need the same chunk of B or A matrices and we can share them using - // DPP. - static_assert(BM101 == dpp8::lane_group_size || BN101 == dpp8::lane_group_size); - static constexpr bool ShareB = BM101 == dpp8::lane_group_size ? true : false; - static constexpr bool ShareA = !ShareB; - - // If DPP shares A (B, respectively), lane group gets `BM1PerThreadBM11` (`BN1PerThreadBN11`, - // respectively) elements, so we split them between threads in lane group so each thread loads - // less data from LDS. - static constexpr index_t BM1PerThread = - ShareA ? BM1PerThreadBM11 / dpp8::lane_group_size : BM1PerThreadBM11; - static constexpr index_t BN1PerThread = - ShareB ? BN1PerThreadBN11 / dpp8::lane_group_size : BN1PerThreadBN11; - - __host__ __device__ static constexpr auto - MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1) - { - const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor( - a_block_desc_bk0_bm_bk1, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - return a_block_bk0_bm0_bm1_bk1; - } - - __host__ __device__ static constexpr auto - MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1) - { - const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor( - b_block_desc_bk0_bn_bk1, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - return b_block_desc_bk0_bn0_bn1_bk1; - } - - __host__ __device__ static constexpr auto - MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN() - { - // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - // lower: [BM, BN] - constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}, Number{})), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); - - return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n; - } - - __host__ __device__ static constexpr auto - MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1() - { - // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - // lower: [BM0, BM1, BN0, BN1] - constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 = - make_single_stage_tensor_adaptor( - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); - - return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1; - } - - __host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1() - { - return Sequence{}; - } - - static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ = - MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{}); - - static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ = - MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); - - public: - __device__ BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0() - : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( - get_thread_local_1d_id())}, - a_thread_copy_{CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()}, - b_thread_copy_{CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()} - { - static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && - BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!"); - - static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) == - BBlockDesc_BK0_BN_BK1{}.GetLength(I0), - "wrong! K dimension not consistent"); - - static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 && - BM10BN10ThreadClusterBN10Xs::Size() == 2, - "wrong!"); - } - - __device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id) - { - // lower: [BM0, BM1, BN0, BN1] - // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - constexpr auto adaptor0 = - MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1(); - - // lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - // upper: [Tid, BM0, BM11, BN0, BN11] - constexpr auto adaptor1 = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)), - make_pass_through_transform(BM0), - make_pass_through_transform(BM11), - make_pass_through_transform(BN0), - make_pass_through_transform(BN11)), - make_tuple( - Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); - - return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); - } - - __device__ AIndex CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1() - { - const auto offsetBM0 = c_thread_origin_data_idx_[I0]; - // If sharing matrix A, we need a separate BM1 offset for each thread in lane group. - const auto offsetBM1 = ShareA ? c_thread_origin_data_idx_[I1] + - dpp8::get_thread_idx_in_lane_group() * BM1PerThread - : c_thread_origin_data_idx_[I1]; - return make_tuple(0, offsetBM0, offsetBM1, 0); - } - - __device__ BIndex CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1() - { - const auto offsetBN0 = c_thread_origin_data_idx_[I2]; - // If sharing matrix B, we need a separate BN1 offset for each thread in lane group. - const auto offsetBN1 = ShareB ? c_thread_origin_data_idx_[I3] + - dpp8::get_thread_idx_in_lane_group() * BN1PerThread - : c_thread_origin_data_idx_[I3]; - return make_tuple(0, offsetBN0, offsetBN1, 0); - } - - template - __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&, - const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - auto a_thread_buf = make_static_buffer( - a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); - - constexpr auto threadwise_contraction = - ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< - FloatA, - FloatB, - FloatC, - decltype(a_thread_desc_bk0_bm0_bm1_bk1_), - decltype(b_thread_desc_bk0_bn0_bn1_bk1_), - CThreadDesc_BM0_BM11_BN0_BN11, - Sequence, - Sequence<1, BM1PerThreadBM11>, - Sequence<1, BN1PerThreadBN11>, - ShareA>{}; - - static_for<0, BN0, 1>{}([&](auto bn0) { - static_for<0, BM0, 1>{}([&](auto bm0) { - a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, - make_tuple(I0, bm0, I0, I0), - a_block_buf, - a_thread_desc_bk0_bm0_bm1_bk1_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, - make_tuple(I0, bn0, I0, I0), - b_block_buf, - b_thread_desc_bk0_bn0_bn1_bk1_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - threadwise_contraction.Run(a_thread_buf, - make_tuple(I0, I0, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - make_tuple(bm0, I0, bn0, I0)); - - static_for{}([&](auto bk0) { - a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, - make_tuple(bk0, bm0, I0, I0), - a_block_buf, - a_thread_desc_bk0_bm0_bm1_bk1_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, - make_tuple(bk0, bn0, I0, I0), - b_block_buf, - b_thread_desc_bk0_bn0_bn1_bk1_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - threadwise_contraction.Run(a_thread_buf, - make_tuple(I0, I0, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - make_tuple(bm0, I0, bn0, I0)); - }); - }); - }); - } - - private: - // A[BK0, BM0, BM1, BK1] - static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{}, Number{})); - - // B[BK0, BN0, BN1, BK1] - static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{}, Number{})); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< - FloatA, - FloatA, - decltype(a_block_desc_bk0_bm0_bm1_bk1_), - decltype(a_thread_desc_bk0_bm0_bm1_bk1_), - Sequence, // SliceLengths - Sequence<0, 1, 2, 3>, // DimAccessOrder - Sequence<1, 1, BM1PerThread, BK1>, // SrcVectorTensorLengths - Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< - FloatB, - FloatB, - decltype(b_block_desc_bk0_bn0_bn1_bk1_), - decltype(b_thread_desc_bk0_bn0_bn1_bk1_), - Sequence, // SliceLengths - Sequence<0, 1, 2, 3>, // DimAccessOrder - Sequence<1, 1, BN1PerThread, BK1>, // SrcVectorTensorLengths - Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder - - CIndex c_thread_origin_data_idx_; - - AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; -}; - -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d62ed4b15dd3a6ea078ef7ed0b70e818a30af611 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp @@ -0,0 +1,348 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/dpp_gemm.hpp" + +namespace ck { + +/** + * Blockwise GEMM that uses DPP instruction modifier to limit the amount of data loaded for each + * thread by sharing the data between threads in a lanegroup. + * + * In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`, there are + * `MRepeat` iterations for `M` dimension and `NRepeat` for `N` one. + * In total, the algorithm runs using + * `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves. + */ +template +struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto dpp_gemm = DppGemm{}; + + static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex_M0_M1_M2_K() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex_K_M(); + const auto dpp_a_idx_k = dpp_a_idx[I0]; + const auto dpp_a_idx_m = dpp_a_idx[I1]; + return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k); + } + + __device__ static auto CalculateBThreadOriginDataIndex_N0_N1_N2_K() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex_K_N(); + const auto dpp_b_idx_k = dpp_b_idx[I0]; + const auto dpp_b_idx_n = dpp_b_idx[I1]; + return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk(); + const auto blk_m_offset = blk_idx[I0]; + const auto blk_n_offset = blk_idx[I1]; + + constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_NPerDpp_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerDpp))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_m_offset))[I0]; + const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_n_offset))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "Wrong! Block descriptors should be known at the time of compilation."); + +#if defined(__HIP_DEVICE_COMPILE__) + // Host wave size can be different than the device one and this assert could fail for host, + // but it does matter only for device. + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); +#endif + + static_assert(MPerBlock % (MPerDpp * MRepeat) == 0, + "Invalid parameters. MPerBlock must be divisible by MPerDpp * MRepeat."); + static_assert(NPerBlock % (NPerDpp * NRepeat) == 0, + "Invalid parameters. NPerBlock must be divisible by NPerDpp * NRepeat."); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths(); + constexpr auto M = c_m_n_tblk_lens[I0]; + constexpr auto N = c_m_n_tblk_lens[I1]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths(); + constexpr auto M = c_m_n_tblk_lens[I0]; + constexpr auto N = c_m_n_tblk_lens[I1]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return c_block_desc_m0_n0_m1_n1_m2_n2; + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + return c_block_desc_g_m0_n0_m1_n1_m2_n2; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return c_grid_desc_m0_n0_m1_n1_m2_n2; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return c_grid_desc_g_m0_n0_m1_n1_m2_n2; + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + static_for<0, KPerThread, KPack>{}([&](auto k) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using dpp_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + dpp_gemm.template Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegDpp] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, dpp_gemm.GetRegSizePerDpp())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex_M0_M1_M2_K()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex_N0_N1_N2_K()}; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7b2aaa76bb48d4785fd6b4cc0b3f001c788fc4fe --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp @@ -0,0 +1,999 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +// Double LDS buffer +// Prefetech 2 stage +// Local prefetch 1 stage + +namespace ck { + +template +struct BlockwiseGemmXdlops_pipeline_hotloop_inst +{ + static constexpr index_t WaveSize = 64; + static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL); + + static constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth); + static constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth); + + static constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth); + static constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth); + + static constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth); + static constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth); + + static constexpr index_t C_MFMA_Inst_Num = + MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + static constexpr auto Print() + { + printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n", + BlockSize, + WaveSize, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + KPerXDL); + + printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " + "%d, %d\n C MFMA inst: %d\n", + A_Buffer_Load_Inst_Num, + B_Buffer_Load_Inst_Num, + A_LDS_Write_Inst_Num, + B_LDS_Write_Inst_Num, + A_LDS_Read_Inst_Num, + B_LDS_Read_Inst_Num, + C_MFMA_Inst_Num); + } +}; + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename ATileDesc, + typename BTileDesc, + typename AMmaTileDesc, + typename BMmaTileDesc, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t MPerXDL, + index_t NPerXDL, + index_t MRepeat, + index_t NRepeat, + index_t KPack, + bool TransposeC = false, + index_t AMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops, + index_t BMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops> +struct BlockwiseGemmXdlops_pipeline_v4 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); + static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = + XdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + static constexpr index_t KRepeat = KPerThread / KPack; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + using HotLoopInstList = BlockwiseGemmXdlops_pipeline_hotloop_inst; + + static_assert(KPerThread % KPack == 0, + "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple( + m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]); + } + + using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + + __host__ __device__ + BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + + // HotLoopInstList::Print(); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + __device__ static constexpr auto HotLoopScheduler() + { + // schedule + constexpr auto num_ds_read_inst = + HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst = + HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num; + ; + constexpr auto num_buffer_load_inst = + HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num; + ; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto num_issue = num_buffer_load_inst; + + static_for<0, num_issue, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA + }); + } + + template + __device__ static constexpr auto TailScheduler() + { + } + + template <> + __device__ static constexpr auto TailScheduler<1>() + { + // schedule + constexpr auto num_ds_read_inst = + HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst = + HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num; + ; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto num_issue = num_ds_write_inst; + + static_for<0, num_issue, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x100, num_ds_read_inst / num_ds_write_inst - 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_inst / num_ds_write_inst - 3, 0); // MFMA + }); + } + + template <> + __device__ static constexpr auto TailScheduler<2>() + { + // schedule + constexpr auto num_ds_read_inst = + HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto num_issue = num_ds_read_inst; + + static_for<0, num_issue, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_inst / num_ds_read_inst, 0); // MFMA + }); + } + + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs; + // Inst List: + // ds_read_b128: 16 + // ds_write_b128: 8 + // buffer_load_dwordx4: 16 + // v_mfma: 0 + // ------------------------------------------------------------------------------------------- + + // Global prefetch 1th, Fill Ping LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0)); + + // Local prefetch 1th, Fill Ping Reg + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(I0)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); + }); + }); + }); + + // Global prefetch 2th, Fill Pong LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1)); + + // Global prefetch 3rd + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + // This hot loop has two legacy loopover, to implement the double local buffer strategy + do + { + // ------------------------------------------------------------------------------------------- + using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{})); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + using PongP2 = Number<0>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP2{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP2{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP2{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP2{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{})); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 2; + } while(i < (num_loop - 3)); + } + + // tail + if constexpr(TailNum == 3) + { + using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{})); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<1>(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + using PongP2 = Number<0>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP2{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP2{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP2{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP2{})); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<2>(); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PongP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PongP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + // 64 v_mfma + __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == 2) + { + using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<2>(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + // 64 v_mfma + __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA + __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + // M1, N1 as double buffer index + // Read buffer + Compute buffer + // A[M0, M1, M2, KPack] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // B[N0, N1, N2, KPack] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index 5ec964bd3fdbda9319d1c6eab0ccb207775a6137..b3d45f3d0c843c5e72a37b4a9a7617f2f1abc16a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -221,49 +221,102 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, I0, I0, I0), - a_thread_buf); - - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0), - b_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; + // basic intrinsic to determine loopover direction + if constexpr(MRepeat < NRepeat) + { + static_for<0, KPerBlock / WmmaK, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0), + b_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + }); }); - - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); }); - }); - }); + } + else + { + static_for<0, KPerBlock / WmmaK, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0), + b_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0), + a_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } } protected: diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index d5a64d7aa6fe6c96f384d3f00af642231b6b2c53..701dd04f6ca19b2de8ed8a630149cd725ac95c3f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -4,27 +4,13 @@ #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" namespace ck { -enum struct LoopScheduler -{ - Default, - Interwave, -}; - -constexpr LoopScheduler make_default_loop_scheduler() -{ -#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING - return LoopScheduler::Interwave; -#else - return LoopScheduler::Default; -#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING -} - template __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) @@ -42,7 +28,8 @@ MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) } template + index_t KPack, + typename ComputeTypeA = FloatA, + typename ComputeTypeB = FloatB> struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { static constexpr auto I0 = Number<0>{}; @@ -72,7 +61,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = + XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; @@ -308,9 +298,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -332,25 +322,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 b_thread_buf); static_for<0, KPerThread, KPack>{}([&](auto k) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf + a_thread_vec.template AsType()(i) = a_thread_buf [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf + b_thread_vec.template AsType()(i) = b_thread_buf [Number{}]; }); - using mfma_input_type = - typename vector_type::type; + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); }); }); @@ -370,8 +362,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -380,8 +372,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1, A_K1>; - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -399,7 +391,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the // default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0 template struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + KPack, + ComputeTypeA, + ComputeTypeB> { using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + KPack, + ComputeTypeA, + ComputeTypeB>; #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING using Base::a_block_desc_m0_m1_m2_k; @@ -454,9 +455,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { @@ -493,20 +494,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = + a_thread_vec.template AsType()(i) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = + b_thread_vec.template AsType()(i) = b_thread_buf[Number{}]; }); - using mfma_input_type = - typename vector_type::type; + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -528,8 +531,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // TODO: insert setprio in more precise manner since we // could have more than >1 MFMA instructions in single call xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) { @@ -555,8 +558,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, I1, I1, Number{})); - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -565,8 +568,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1, A_K1>; - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -582,7 +585,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 }; template + LoopScheduler LoopSched, + typename ComputeTypeA = FloatA, + typename ComputeTypeB = FloatB> constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() { if constexpr(LoopSched == LoopScheduler::Default) { return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + KPack, + ComputeTypeA, + ComputeTypeB>{}; } else if constexpr(LoopSched == LoopScheduler::Interwave) { return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + KPack, + ComputeTypeA, + ComputeTypeB>{}; } }; @@ -632,26 +644,27 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() * 3. configurable k index starting position and step size after each FMA/XDL instruction */ -template {}.K0PerXdlops, - index_t BMmaKStride = - KPack* XdlopsGemm{}.K0PerXdlops> +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename ATileDesc, + typename BTileDesc, + typename AMmaTileDesc, + typename BMmaTileDesc, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t MPerXDL, + index_t NPerXDL, + index_t MRepeat, + index_t NRepeat, + index_t KPack, + bool TransposeC = false, + index_t AMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops, + index_t BMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops> struct BlockwiseGemmXdlops_v2 { static constexpr auto I0 = Number<0>{}; @@ -668,7 +681,8 @@ struct BlockwiseGemmXdlops_v2 static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = + XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a737c9195b436b592cffc89dc272706d80cb8fb4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +/** + * Transfer that uses direct load instructions to copy data from global to LDS memory. + * + * Traditional loads first copy data from global to registers, and then from registers to LDS. + * Direct loads do not need an intermediate step, data is copied directly from global to LDS, + * without the use of additional registers. + * + * However, the instruction has limitations: + * - each thread must copy exactly a single DWORD - 4 bytes; + * - threads within a single wavefront must write consecutive DWORDS into LDS, + * (data in global do not need to be contiguous, each thread might have its own offset). + * + * To make sure that all the transfers finished, the `waitcnt` instruction must be used with + * `vmcnt` instead of `lgkmcnt`. + * + * Limitations of the transfer class: + * - `SrcData` must be the same as `DstData` - no possibility to convert the data type in flight; + * - `DstVectorDim` must be the last dimension; + * - `SrcVectorDim` must be the last dimension if `ScalarPerVector` is greater than 1; + * - `ScalarPerVector` times the number of bytes of `DstData` must be equal to a single DWORD = 4B + * (for examlpe if `DstData` is fp32, then `ScalarPerVector` must be 1; if `DstData` is fp16, + * `ScalarPerVector` must be 2); + * - if `ScalarPerVector` is greater than 1, the contiguous dimension in src and dst must be + * the same dimension; + * - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64, + * they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way + * to guarantee that. + */ +template +struct ThreadGroupTensorSliceTransfer_DirectLoad +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + static constexpr auto block_slice_lengths = BlockSliceLengths{}; + static constexpr auto thread_cluster_lengths = ThreadClusterLengths{}; + + static constexpr auto thread_single_load_size = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + // After a load, each thread moves by `thread_steps` instead of loading the next elements. + // It makes the whole wavefront load contiguous memory, what is required for direct loads. + static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; + static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps; + + static __device__ constexpr bool AreThreadClusterLengthsValid() + { + // Make sure that ThreadClusterLengths are set in a way that allows for contiguous writes to + // LDS by the threads from a single wavefront. + // Examples (assuming 64 threads in a wavefront, 128 in a thread block): + // 1. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8], + // data type = fp32 -> ScalarPerVector = 1 + // INVALID: ThreadClusterLengths = [4, 4, 8] since in the first iteration, threads 0-31 + // write [0, 0, 0] - [0, 3, 7] and thread 32 writes [1, 0, 0] instead of + // [0, 4, 0]. + // VALID: ThreadClusterLengths = [2, 8, 8] or [1, 16, 8] since in the first iteration, + // threads 0-63 write [0, 0, 0] - [0, 7, 7] -> 64 consecutive elements (DWORDs). + // 2. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8], + // data type = fp16 -> ScalarPerVector = 2 + // NOTE: ThreadClusterLengths must take into account that each thread writes two + // elements (single DWORD) along the contiguous dimension. + // INVALID: ThreadClusterLengths = [4, 4, 8] since each 8 threads would try to write + // 8 * 2 elements of K1PerBlock and there are only 8; + // ThreadClusterLengths = [4, 8, 4] since in the first iteration, threads 0-31 + // write [0, 0, 0] - [0, 7, 7] (7 since each writes 2 elements) and thread 32 + // writes [1, 0, 0] instead of [0, 8, 0]. + // VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the + // first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive + // elements = 64 consecutive DWORDs. + int num_contiguous_dwords = 1; + bool is_contiguous = true; + static_for<0, nDim, 1>{}([&](auto i) { + if(is_contiguous) + { + num_contiguous_dwords *= thread_cluster_lengths[nDim - i - 1]; + } + if(thread_slice_lengths[nDim - i - 1] > 1) + { + is_contiguous = false; + } + }); + constexpr index_t wavefront_size = get_warp_size(); + const bool wave_contiguous = num_contiguous_dwords % wavefront_size == 0; + + bool thread_slice_lengths_correct = true; + static_for<0, nDim, 1>{}([&](auto i) { + if(thread_slice_lengths[i] <= 0) + { + thread_slice_lengths_correct = false; + } + }); + + return wave_contiguous && thread_slice_lengths_correct; + } + + __device__ constexpr ThreadGroupTensorSliceTransfer_DirectLoad( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin) + + { + static_assert(ck::is_same_v, + "Direct load transfer does not support datatypes conversion. Source and " + "destination data types must be the same."); + + static_assert( + DstVectorDim == nDim - 1, + "Direct load transfer requires the destination vector dimension to be the last one."); + + static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim, + "When loading more than one element per thread at once, the contiguous " + "dimension must be the same between source and destination."); + + constexpr auto dword_bytes = 4; + constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData); + static_assert(bytes_per_thread_load == dword_bytes, + "Direct load transfer requires each thread to load exactly a single " + "DWORD of data."); + + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size(), + "Inconsistent number of dimensions across lengths and descriptors."); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "The number of threads cannot be less than the number of elements in " + "thread cluster lengths."); + + static_assert( + AreThreadClusterLengthsValid(), + "Thread cluster lengths are incorrect. They must be set in a way that allows a single " + "wavefront to write contiguous DWORDs into LDS memory. "); + + const auto thread_cluster_idx = + thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size; + + SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin); + SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + src_slice_origin_ = src_slice_origin_idx; + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + dst_slice_origin_ = dst_slice_origin_idx; + } + + __device__ void ResetDstSliceWindow(const DstDesc& dst_desc) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global, + "Source data must come from a global memory buffer."); + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "Destination data must be stored in an LDS memory buffer."); + + static_assert( + ck::is_same_v, remove_cvref_t>, + "SrcBuffer and SrcData data types must be consistent."); + static_assert( + ck::is_same_v, remove_cvref_t>, + "DstBuffer and DstData data types must be consistent."); + + constexpr auto dst_access_lengths = thread_slice_lengths; + + const auto dst_forward_steps = generate_steps(dst_desc, 1); + const auto dst_backward_steps = generate_steps(dst_desc, -1); + const auto src_forward_steps = generate_steps(src_desc, 1); + const auto src_backward_steps = generate_steps(src_desc, -1); + + // Loop over the destination block and copy data. + static_ford{}([&](auto ordered_dst_access_idx) { + const auto src_offset = src_coord_.GetOffset(); + const auto dst_offset = dst_coord_.GetOffset(); + + // Check if src data is not in the logic padding area. + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + src_buf.template DirectCopyToLds, ScalarPerVector>( + dst_buf, src_offset, dst_offset, is_src_valid); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // Decide whether to move forward or backward. + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]); + move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]); + } + else + { + move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]); + move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]); + } + } + }); + }); + + // Reset the destination slice since the entire buffer has been already filled. + ResetDstSliceWindow(dst_desc); + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + src_slice_origin_ = src_slice_origin_ + step; + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_); + } + + template + __device__ auto generate_steps(const DescType& desc, int sign) + { + return generate_tuple( + [&](auto i) { + Index step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + step_idx(j) = (i.value == j.value) ? sign * thread_steps[i] : 0; + }); + + return make_tensor_coordinate_step(desc, step_idx); + }, + Number{}); + } + + private: + static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{}); + + SrcCoord src_coord_; + DstCoord dst_coord_; + Index src_slice_origin_; + Index dst_slice_origin_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1a9bb3213a541c49358035b1b51bbf28aac26d87 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp @@ -0,0 +1,214 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does following things to avoid scratch memory issue +// 1. Pass tensor descritpors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + index_t SrcScalarPerVector, + index_t DstScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags> +struct ThreadGroupTensorSliceTransfer_v7r2 +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs); + } + } + + template + using is_tuple = decltype(std::declval().IsTuple()); + + template + __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + if constexpr(is_detected::value) + threadwise_transfer_.RunWrite(dst_descs, dst_bufs); + else + threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs)); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) + { + static_for<0, SrcDescs::Size(), 1>{}( + [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); }); + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step) + { + static_for<0, DstDescs::Size(), 1>{}( + [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); }); + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp b/include/ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dc08a2c88b2abbd9b26412121ae223af6c3e1c01 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace conv_tensor_rearrange_op { + +struct BaseConvTensorRearrangeOp +{ +}; + +struct ImageToColumn : public BaseConvTensorRearrangeOp +{ + static constexpr const char* name = "Image to Column"; +}; + +struct ColumnToImage : public BaseConvTensorRearrangeOp +{ + static constexpr const char* name = "Column to Image"; +}; + +template ::value, + bool>::type = false> +std::ostream& operator<<(std::ostream& os, const BaseConvTensorRearrangeOp&) +{ + os << Op::name; + return os; +} + +} // namespace conv_tensor_rearrange_op +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index f59021a8a079f9e9f8adcfa7f8c406c8515712c8..08371ed885365f6135e80e1caac944826310b2c8 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -62,7 +62,9 @@ struct BaseOperator }; virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } - virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const + virtual void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const { assert(p_arg); p_arg->p_workspace_ = p_workspace; diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3116d404742e1769a1fc6c9b75a1f22f5eea689a --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A0[M0, M1, ... K0, K1, ...], ... +// input : B0[N0, N1, ... K0, K1, ...], ... +// input : D0[M0, M1, ... N0, N1, ...], D1[M0, M1, ... N0, N1, ...], ... +// output : E[M0, M1, ... N0, N1, ...] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceContractionMultipleABD : public BaseOperator +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + const std::array, NumATensor>& a_ms_ks_lengths, + const std::array, NumATensor>& a_ms_ks_strides, + const std::array, NumBTensor>& b_ns_ks_lengths, + const std::array, NumBTensor>& b_ns_ks_strides, + const std::array, NumDTensor>& d_ms_ns_lengths, + const std::array, NumDTensor>& d_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp index 118ade8978643f10dcfafbefb552fe485d6697e4..dffb51d0ba7517446df5c9a3901a10be33b99db8 100644 --- a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp @@ -33,7 +33,8 @@ template + typename CDEElementwiseOperation, + typename ComputeDataType = ADataType> struct DeviceContractionMultipleD : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); diff --git a/include/ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b28204fd841cd1430e30a255dacf5b9a4583a929 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * \brief Convolution Tensor Rearrange. + * + * This Device operator supports converting an image to + * the GEMM representation (Image to Column) and + * converting a GEMM form to the image (Column to Image). + * Supported layouts: + * [G, N, Di, Hi, Wi, C] <-> [G, N * Do * Ho * Wo, Z * Y * X * C] + * [N, Di, Hi, Wi, G, C] <-> [N * Do * Ho * Wo, G, Z * Y * X * C] + * + * \tparam NDimSpatial Number of spatial dimensions. + * \tparam ImageLayout Input Layout. + * \tparam InputDataType Input Data Type. + * \tparam OutputDataType Output Data Type. + * \tparam ConvTensorRearrangeOp Operation type: ImageToColumn, ColumnToImage. + */ +template +struct DeviceConvTensorRearrange : public BaseOperator +{ + + /** + * \brief Make argument pointer for image to column. + * + * \param p_in A pointer to the device memory of the input image. + * \param p_out A pointer to the device memory of the output. + * \param G Convolution number of groups. + * \param N Convolution batch size. + * \param C Convolution number of channels. + * \param input_spatial_lengths Input spatial lengths. + * \param filter_spatial_lengths Filter spatial lengths. + * \param output_spatial_lengths Output spatial lengths. + * \param image_g_n_c_wis_strides Image strides in order [G, N, C, D, H, W]. + * \param gemm_g_m_k_strides Gemm form strides. + * \param conv_filter_strides Convolution filter strides. + * \param conv_filter_dilations Convolution filter dilations. + * \param input_left_pads Convolution left pads. + * \param input_right_pads Convolution right pads. + * \return Pointer to the argument. + */ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + void* p_out, + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp b/include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3b0cbc6e5fffd35ee361856db621f53beccf7dda --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceElementwise : public BaseOperator +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op, + UnaryOperation unary_op, + Scale scale_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; // namespace device + +template +using DeviceElementwisePtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cbb9fadc6d6f3f03673f4c5a7f13229f29b5d849 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A0[M, K], B0[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleABD : public BaseOperator +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + ck::index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp index 6407aa7e09b81a4dc09dcb1bdca5701aa86d9e14..dfc27e726eaafaddf13ac62c530567eb32ab0ddc 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp @@ -20,7 +20,8 @@ template + typename CElementwiseOperation, + typename ComputeType = CDataType> struct DeviceGemmSplitK : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const void* p_a, @@ -48,7 +49,8 @@ template + typename CElementwiseOperation, + typename ComputeType = CDataType> using DeviceGemmSplitKPtr = std::unique_ptr>; + CElementwiseOperation, + ComputeType>>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp index 7e4bca2bd66a55c70b6a9365e2f1debcc22bb119..2abf1d5a10c7e8645ca51d842390d50d5fbce5dd 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp @@ -29,7 +29,9 @@ template + typename CDEElementwiseOperation, + typename AComputeType = ADataType, + typename BComputeType = AComputeType> struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp index ab9e6adb41963df1f9cd56ca2a7ed3056275defe..7296e4faaa30c56cab55ac5140b94a20c45782b1 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -20,7 +20,9 @@ template + typename OutElementwiseOperation, + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedConvBwdWeight : public BaseOperator { virtual std::unique_ptr diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fa3dcfdf207a21cea8f2628df41b016c502d8a30 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +using is_tuple = decltype(std::declval().IsTuple()); + +/** + * \brief Grouped Convolution Forward + * + * \details + * input : input image A[G, N, C, Hi, Wi], A1[G, N, C, Hi, Wi]... + * input : weight B[G, K, C, Y, X], B1[G, K, C, Y, X]... + * input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ... + * output : output image E[G, N, K, Ho, Wo] + * + * C = a_op(A, A1...) * b_op(B, B1...) + * E = cde_op(C, D0, D1, ...) + * + * \tparam NDimSpatial Number of spatial dimensions. + * \tparam ALayout Input layout (also for a1, a2...). + * \tparam BLayout Weight layout (also for b1, b2...). + * \tparam DsLayout Ds layouts. + * \tparam ELayout Output layout. + * \tparam ADataType Input data type. Pass tuple if there is multiple A. + * \tparam BDataType Weight data type. Pass tuple if there is multiple B. + * \tparam DsDataType D data types. + * \tparam EDataType Output data type. + * \tparam AElementwiseOperation A elementwise operation. + * \tparam BElementwiseOperation B elementwise operation. + * \tparam CDEElementwiseOperation CDE elementwise operation. + * \tparam ComputeType Compute data type (default: ADataType, first if tuple passed). + */ +template ::value, + Number<0>, + ADataType>())> // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed +struct DeviceGroupedConvFwdMultipleABD : public BaseOperator +{ + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor"); + + // If DataType is tuple, user has to pass std::array with pointers. + using APointers = + std::conditional_t&, const void*>; + using BPointers = + std::conditional_t&, const void*>; + + /** + * \brief Make argument pointer for grouped conv fwd. + * + * \param p_a A pointer to the input (std::array with + pointers for multiple A). + * \param p_b A pointer to the weight (std::array with + pointers for multiple B). + * \param p_ds A pointers to the Ds. + * \param p_e A pointers to the output. + * \param a_g_n_c_wis_lengths Input lengths [G, N, C, Spatial...] (for 3d). + * \param a_g_n_c_wis_strides Input strides [G, N, C, Spatial...] (for 3d). + * \param b_g_k_c_xs_lengths Weight lengths [G, K, C, Spatial...] (for 3d). + * \param b_g_k_c_xs_strides Weight strides [G, K, C, Spatial...] (for 3d). + * \param ds_g_n_k_wos_lengths Ds lengths [G, N, K, Spatial...] (for 3d). + * \param ds_g_n_k_wos_strides Ds strides [G, N, K, Spatial...] (for 3d). + * \param e_g_n_k_wos_lengths Output lengths [G, N, K, Spatial...] (for 3d). + * \param e_g_n_k_wos_strides Output strides [G, N, K, Spatial...] (for 3d). + * \param conv_filter_strides Convolution filter strides. + * \param conv_filter_dilations Convolution filter dilations. + * \param input_left_pads Input left paddings. + * \param input_right_pads Input right paddings. + * \param a_element_op A elementwise operation object. + * \param b_element_op B elementwise operation object. + * \param cde_element_op CDE elementwise operation object. + * \return Pointer to the argument. + */ + virtual std::unique_ptr MakeArgumentPointer( + APointers p_a, + BPointers p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp index 1cc30fd9e6ad10336a49a261b3ee7c47f9053435..daf397189a4ed91d190a463ef3ec4212837e5abf 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp @@ -3,21 +3,33 @@ #pragma once -#include - -#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" namespace ck { namespace tensor_operation { namespace device { -// Convolution Forward: -// input : input image A[G, N, C, Hi, Wi], -// input : weight B[G, K, C, Y, X], -// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ... -// output : output image E[G, N, K, Ho, Wo] -// C = a_op(A) * b_op(B) -// E = cde_op(C, D0, D1, ...) +/** + * \brief Grouped Convolution Forward + * + * \note This structure is deprecated (left for backwards compatibility). Please use + * DeviceGroupedConvFwdMultipleABD. + * + * \tparam NDimSpatial Number of spatial dimensions. + * \tparam ALayout Input layout (also for a1, a2...). + * \tparam BLayout Weight layout (also for b1, b2...). + * \tparam DsLayout Ds layouts. + * \tparam ELayout Output layout. + * \tparam ADataType Input data type. Pass tuple if there is multiple A. + * \tparam BDataType Weight data type. Pass tuple if there is multiple B. + * \tparam DsDataType D data types. + * \tparam EDataType Output data type. + * \tparam AElementwiseOperation A elementwise operation. + * \tparam BElementwiseOperation B elementwise operation. + * \tparam CDEElementwiseOperation CDE elementwise operation. + * \tparam ComputeType Compute data type (default: ADataType, first if tuple passed). + */ template -struct DeviceGroupedConvFwdMultipleD : public BaseOperator -{ - static constexpr index_t NumDTensor = DsDataType::Size(); - - static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor"); - - virtual std::unique_ptr MakeArgumentPointer( - const void* p_a, // input image - const void* p_b, // weight - const std::array& p_ds, - void* p_e, // output image - const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array, NumDTensor>& ds_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CDEElementwiseOperation& cde_element_op) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; -}; + typename CDEElementwiseOperation, + typename ComputeType = + decltype(UnpackDataType::value, + Number<0>, + ADataType>())> // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed +using DeviceGroupedConvFwdMultipleD = DeviceGroupedConvFwdMultipleABD; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fcb2ba6a4d7be839c11cf9794bb7beccf7845d3c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct GroupedGemmKernelArgument +{ + const void* p_a_grid; + const void* p_b_grid; + std::array p_ds_grid; + void* p_e_grid; + + index_t M; + index_t N; + index_t K; + + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideE; +}; + +template +struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm +{ + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0; + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; + virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp b/include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp similarity index 83% rename from include/ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp rename to include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp index bf81ed9f5b1c2e6b46208ffede15370a3ae78088..5a4a9cac1e8ea78b6683b1c4ce9629a8c87b9802 100644 --- a/include/ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp @@ -13,7 +13,7 @@ namespace device { // For pooling which used indexable operation, such as MaxPool, MinPool...etc template -struct DeviceIndexPoolBwd : public BaseOperator +struct DeviceMaxPoolBwd : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const void* p_dout, @@ -22,7 +22,8 @@ struct DeviceIndexPoolBwd : public BaseOperator index_t dout_length, index_t din_length, std::vector window_lengths, - std::vector window_strides) = 0; + std::vector window_strides, + std::vector window_dilations) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; diff --git a/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp b/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp new file mode 100644 index 0000000000000000000000000000000000000000..327acfeb53e9f468bfdef4713a8e3b748f876eca --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct DeviceNormalizationBwdData : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_gamma, + const void* p_mean, + const void* p_invStd, + void* p_dx) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceNormalizationBwdDataPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp b/include/ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5a8804cbf0e86b48b9477978a4691f6159af3f09 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct DeviceNormalizationBwdGammaBeta : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const std::vector inLengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector outLengths, + const std::vector dgammaStrides, + const std::vector dbetaStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_mean, + const void* p_invStd, + void* p_dgamma, + void* p_dbeta) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceNormalizationBwdGammaBetaPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_normalization_fwd.hpp similarity index 79% rename from include/ck/tensor_operation/gpu/device/device_normalization.hpp rename to include/ck/tensor_operation/gpu/device/device_normalization_fwd.hpp index 1f178f9fcb65ffdd7ab09146d94131c1d5c993f2..d252ad1d98f57b962e262efef74b5ae20e8a071c 100644 --- a/include/ck/tensor_operation/gpu/device/device_normalization.hpp +++ b/include/ck/tensor_operation/gpu/device/device_normalization_fwd.hpp @@ -14,12 +14,12 @@ namespace device { template -struct DeviceNormalization : public BaseOperator +struct DeviceNormalizationFwd : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const std::vector lengths, @@ -27,6 +27,8 @@ struct DeviceNormalization : public BaseOperator const std::vector gammaStrides, const std::vector betaStrides, const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, const std::vector reduceDims, double epsilon, const void* p_x, @@ -43,19 +45,19 @@ struct DeviceNormalization : public BaseOperator template -using DeviceNormalizationPtr = std::unique_ptr>; +using DeviceNormalizationFwdPtr = std::unique_ptr>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp b/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp deleted file mode 100644 index 16ca582b89c70a2819373fea6962daf34ff2b8a2..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp +++ /dev/null @@ -1,18 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -namespace ck { -namespace tensor_operation { -namespace device { - -enum struct GemmDlAlgorithm -{ - Default, // Uses DOT vector instructions - Dpp8, // Uses DOT vector instructions with DPP8 SEL modifier to reduce data loads from LDS -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index 4d599e8017991b4f1b1905de9acd23a219ac67f9..b32f3a8daadda21f5149c10198ae699a90047c0b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -770,8 +770,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102") + if(ck::is_navi3_supported()) { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index fe971171811b8a1ecba4b53ed509a9b546d20ffb..64aa398d531e634c54c3d55d591789f72943bd0e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -57,7 +57,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = @@ -543,9 +543,13 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle EGridDesc_G_M_N e_grid_desc_g_m_n_; }; + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ADataType, // TODO: distinguish A/B datatype + ADataType, + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index d6b6405bb70af4699178b97f6a5b5052f136f4a5..d06eab1264684b595dc87dc1cf91a9f2b44a5056 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -75,7 +75,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -331,8 +331,13 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute, // DsDataType, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index 3dbe8c67226ec4e4f0dc5bf4ac02ac49ae73cf46..e950169ccfe58887351c4bfbc82af85a061287cf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -61,7 +61,7 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp index 3df2ee38f2915b1112186e2a3f1fbfd64a3b174a..d6b92bc97a8d3eed0d0ce50ad26a9169c498671a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -84,7 +84,7 @@ __global__ void { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -324,8 +324,12 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD" << " NumGemmKPrefetchStage: " << NumGemmKPrefetchStage << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp index f46237e005642d433c29f5ab0f027a224043dbef..3b62cf10a361144693255049520e891add60635a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp @@ -376,7 +376,9 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd(pArg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp index ad8e7956033db31db2992d6fa9f86cf16a29b150..e7e4668d92e72d985168462358eaa088b5dc385f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp @@ -354,7 +354,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd(pArg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp index b826793c27be946a99d0d96a99d745fd6c618822..c3e0837722cfd3f2cec9f2c20a19165943a9ef6d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp @@ -345,7 +345,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd(pArg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4c6546239beb642f680b5dc639185ebe2b6e0ccb --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp @@ -0,0 +1,646 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" + +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Column to Image: +// input : gemm form [G, N * Do * Ho * Wo, Z * Y * X * C] +// output : input image [G, N, Di, Hi, Wi, C] +// input : gemm form [N * Do * Ho * Wo, G, Z * Y * X * C] +// output : input image [N, Di, Hi, Wi, G, C] +template = 1 && NDimSpatial <= 3, bool>::type = false> +struct DeviceColumnToImageImpl + : public DeviceConvTensorRearrange +{ + static constexpr bool is_NSpatialGC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + static constexpr bool is_GNSpatialC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = NDimSpatial == 1 ? I0 : Number{}; + static constexpr auto XIdx = Number{}; + + static constexpr auto spatial_offset = Number<3>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + static constexpr auto matrix_padder = + MatrixPadder{ + MPerBlock, 0 /* NPerBlock*/, KPerBlock}; + + // Calculate number of independent filters for given conv params + static index_t GetNumberOfIndependentFilters(const index_t input_spatial_len, + const index_t left_pad, + const index_t right_pad, + const index_t filter_len, + const index_t filter_stride, + const index_t filter_dilation, + const index_t image_offset) + { + const index_t x_eff = (filter_len - 1) * filter_dilation + 1; + const index_t next_filter_padded = + math::integer_divide_ceil(x_eff, filter_stride) * filter_stride; + // If filter_stride >= x_eff then each filter is independent + const index_t independent_filter_stride = + filter_stride >= x_eff ? filter_stride : next_filter_padded; + const index_t w_eff = input_spatial_len - image_offset + left_pad + right_pad - x_eff; + // There are no independent filters + if(w_eff < 0) + return 0; + const index_t independent_kernels_num = w_eff / independent_filter_stride + 1; + return independent_kernels_num; + } + + // Make column form descriptor + static auto + MakeInputDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& conv_filter_strides, + const std::array& gemm_g_m_k_strides, + const std::array& independent_filters, + const std::array& effs) + { + const index_t DoHoWo = ck::accumulate_n( + output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t CZYX = + C * ck::accumulate_n( + filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + + const index_t NStride = DoHoWo * gemm_g_m_k_strides[I1] * gemm_g_m_k_strides[I2]; + // Calculate the appropriate stride for each set of independent filters + // in each dimension + const index_t WStride = math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) * + gemm_g_m_k_strides[I1]; + const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) * + output_spatial_lengths[XIdx] * gemm_g_m_k_strides[I1]; + const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) * + output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] * + gemm_g_m_k_strides[I1]; + // Create descriptor for independent filters in each dimension and + // then merge them into column form + if constexpr(NDimSpatial == 1) + { + const auto desc_gemm_form = + make_naive_tensor_descriptor(make_tuple(N, independent_filters[XIdx], CZYX), + make_tuple(NStride, WStride, gemm_g_m_k_strides[I2])); + const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( + desc_gemm_form, + make_tuple(make_merge_transform(make_tuple(N, independent_filters[XIdx])), + make_pass_through_transform(CZYX)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters); + return desc_m_k; + } + else if constexpr(NDimSpatial == 2) + { + const auto desc_gemm_form = make_naive_tensor_descriptor( + make_tuple(N, independent_filters[YIdx], independent_filters[XIdx], CZYX), + make_tuple(NStride, HStride, WStride, gemm_g_m_k_strides[I2])); + const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( + desc_gemm_form, + make_tuple(make_merge_transform( + make_tuple(N, independent_filters[YIdx], independent_filters[XIdx])), + make_pass_through_transform(CZYX)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters); + return desc_m_k; + } + else if constexpr(NDimSpatial == 3) + { + const auto desc_gemm_form = make_naive_tensor_descriptor( + make_tuple(N, + independent_filters[ZIdx], + independent_filters[YIdx], + independent_filters[XIdx], + CZYX), + make_tuple(NStride, DStride, HStride, WStride, gemm_g_m_k_strides[I2])); + const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( + desc_gemm_form, + make_tuple(make_merge_transform(make_tuple(N, + independent_filters[ZIdx], + independent_filters[YIdx], + independent_filters[XIdx])), + make_pass_through_transform(CZYX)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters); + return desc_m_k; + } + } + + // Use MakeADescriptor_M_K from grouped convolution forward + static auto + MakeOutDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const std::array& image_offsets, + const std::array& independent_filters, + const std::array& effs) + { + std::array a_g_n_c_wis_lengths{1}; + std::array b_g_k_c_xs_lengths{1}; + std::array c_g_n_k_wos_lengths{1}; + + auto copy = [](const auto& x, auto& y, index_t dst_offset) { + std::copy(x.begin(), x.end(), y.begin() + dst_offset); + }; + + copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset); + copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset); + // Calculate descriptor only for independent filters + copy(independent_filters, c_g_n_k_wos_lengths, spatial_offset); + + // fill only significant values (C and N) + a_g_n_c_wis_lengths[I1] = N; + a_g_n_c_wis_lengths[I2] = C; + b_g_k_c_xs_lengths[I2] = C; + c_g_n_k_wos_lengths[I1] = N; + + // Modify pads to apply offsets + std::array input_left_pads_with_offset; + for(index_t i = 0; i < NDimSpatial; i++) + { + input_left_pads_with_offset[i] = math::max(0, input_left_pads[i] - image_offsets[i]); + } + // Modify input spatial lengths to apply offsets + for(index_t i = 0; i < NDimSpatial; i++) + { + a_g_n_c_wis_lengths[i + spatial_offset] -= + math::max(0, image_offsets[i] - input_left_pads[i]); + } + + // Strides to next independent filters + std::array independent_filter_strides; + for(index_t i = 0; i < NDimSpatial; i++) + { + index_t independent_filter_stride = + math::integer_divide_ceil(effs[i], conv_filter_strides[i]) * conv_filter_strides[i]; + // If conv stride is greater than whole filter size, use conv stride + independent_filter_strides[i] = conv_filter_strides[i] >= effs[i] + ? conv_filter_strides[i] + : independent_filter_stride; + } + + // Calculate image form descriptor for the modified convolution problem + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K( + a_g_n_c_wis_lengths, + image_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + {}, // not needed for A Descriptor + c_g_n_k_wos_lengths, + {}, // not needed for A Descriptor + // conv_filter_strides, + independent_filter_strides, + conv_filter_dilations, + input_left_pads_with_offset, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + return in_gemmm_gemmk_desc; + } + + using InputGridDesc = + remove_cvref_t; + using OutputGridDesc = remove_cvref_t; + + using Block2ETileMap = remove_cvref_t< + decltype(BlockToCTileMap_M00_N0_M01Adapt( + InputGridDesc{}))>; + + using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange>; + + struct Argument : public BaseArgument + { + Argument(const void* p_in, // input image + void* p_out, // output image + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + : G_(G), + C_(C), + X_(filter_spatial_lengths[NDimSpatial - I1]), + p_in_{static_cast(p_in)}, + p_out_{static_cast(p_out)}, + image_g_n_c_wis_strides_{image_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + compute_ptr_offset_of_batch_.BatchStrideA_ = gemm_g_m_k_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = image_g_n_c_wis_strides[I0]; + + const index_t x_eff = + (filter_spatial_lengths[XIdx] - 1) * conv_filter_dilations[XIdx] + 1; + const index_t y_eff = + NDimSpatial < 2 + ? I1 + : (filter_spatial_lengths[YIdx] - 1) * conv_filter_dilations[YIdx] + 1; + const index_t z_eff = + NDimSpatial < 3 + ? I1 + : (filter_spatial_lengths[ZIdx] - 1) * conv_filter_dilations[ZIdx] + 1; + + // Iterate over sets of independent filters + for(int z_img_offset = 0; z_img_offset < z_eff; + z_img_offset += conv_filter_strides[ZIdx]) + { + for(int y_img_offset = 0; y_img_offset < y_eff; + y_img_offset += conv_filter_strides[YIdx]) + { + for(int x_img_offset = 0; x_img_offset < x_eff; + x_img_offset += conv_filter_strides[XIdx]) + { + + std::array image_offsets; + std::array effs; + // Calculate the starting offset for a given set of + // independent filters + if constexpr(NDimSpatial == 1) + { + image_offsets = {x_img_offset}; + effs = {x_eff}; + } + if constexpr(NDimSpatial == 2) + { + image_offsets = {y_img_offset, x_img_offset}; + effs = {y_eff, x_eff}; + } + else if constexpr(NDimSpatial == 3) + { + image_offsets = {z_img_offset, y_img_offset, x_img_offset}; + effs = {z_eff, y_eff, x_eff}; + } + + std::array independent_filters; + for(index_t i = 0; i < NDimSpatial; i++) + { + independent_filters[i] = + GetNumberOfIndependentFilters(input_spatial_lengths[i], + input_left_pads[i], + input_right_pads[i], + filter_spatial_lengths[i], + conv_filter_strides[i], + conv_filter_dilations[i], + image_offsets[i]); + } + const index_t independent_filters_acum = ck::accumulate_n( + independent_filters.begin(), NDimSpatial, 1, std::multiplies<>()); + if(independent_filters_acum <= 0) + continue; + + const auto in_grid_desc_m_k = + MakeInputDescriptor_M_K(N, + C, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + gemm_g_m_k_strides, + independent_filters, + effs); + const auto out_grid_desc_m_k = + MakeOutDescriptor_M_K(N, + C, + input_spatial_lengths, + filter_spatial_lengths, + image_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + image_offsets, + independent_filters, + effs); + in_grid_desc_m_k_container_.push_back(in_grid_desc_m_k); + out_grid_desc_m_k_container_.push_back(out_grid_desc_m_k); + + const index_t x_idx = x_img_offset / conv_filter_strides[XIdx]; + const index_t y_idx = y_img_offset / conv_filter_strides[YIdx]; + const index_t z_idx = z_img_offset / conv_filter_strides[ZIdx]; + + const index_t x_offset_with_pad = + math::max(0, x_img_offset - input_left_pads[XIdx]); + const index_t y_offset_with_pad = + math::max(0, y_img_offset - input_left_pads[YIdx]); + const index_t z_offset_with_pad = + math::max(0, z_img_offset - input_left_pads[ZIdx]); + + // Memory offsets to next set of independent filters, + // move to independent filters in each dimension + const index_t in_offset = + (x_idx + y_idx * output_spatial_lengths[XIdx] + + z_idx * output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx]) * + gemm_g_m_k_strides[I1]; + // Move to independent filters in appropriate dimensions + const index_t out_offset = + x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] + + y_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + YIdx] + + z_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + ZIdx]; + + const InputDataType* p_in_with_offset = + static_cast(p_in) + in_offset; + OutputDataType* p_out_with_offset = + static_cast(p_out) + out_offset; + p_in_container_.push_back(p_in_with_offset); + p_out_container_.push_back(p_out_with_offset); + } + } + } + } + + void Print() const + { + for(std::size_t i = 0; i < in_grid_desc_m_k_container_.size(); i++) + { + std::cout << in_grid_desc_m_k_container_[i] << std::endl; + std::cout << out_grid_desc_m_k_container_[i] << std::endl; + } + } + + const ck::index_t G_; + const ck::index_t C_; + const ck::index_t X_; + + const InputDataType* p_in_; + OutputDataType* p_out_; + + const std::array& image_g_n_c_wis_strides_; + const std::array& conv_filter_strides_; + const std::array& conv_filter_dilations_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + + std::vector in_grid_desc_m_k_container_; + std::vector out_grid_desc_m_k_container_; + + std::vector p_in_container_; + std::vector p_out_container_; + + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float elapsed_time = 0.f; + const auto kernel = kernel_tensor_rearrange, + GridwiseTensorRearrangeKernel>; + + // Execute each set of independent filters + for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++) + { + const auto block_2_tile_map = + BlockToCTileMap_M00_N0_M01Adapt( + arg.out_grid_desc_m_k_container_[i]); + const index_t grid_size = + block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]) * arg.G_; + elapsed_time += launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.in_grid_desc_m_k_container_[i], + arg.p_in_container_[i], + arg.out_grid_desc_m_k_container_[i], + arg.p_out_container_[i], + arg.G_, + block_2_tile_map, + arg.compute_ptr_offset_of_batch_); + } + return elapsed_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const Argument& arg) + { + using namespace tensor_layout::convolution; + if constexpr(!(is_NSpatialGC || is_GNSpatialC)) + { + return false; + } + + const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1]; + const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1]; + const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1]; + const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1]; + bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_; + bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1; + + // check vector acces with c not packed + if(!is_c_packed && ScalarPerVector != 1) + return false; + // check vector access of filter window row (only C if C is not packed) + if(!is_w_packed && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of filter window row (X * C) + if(arg.X_ * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of pads (w_pad_left/w_pad_right * C) + if(w_pad_left * arg.C_ % ScalarPerVector != 0 || + w_pad_right * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with stride and pad + if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with dilation + if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + + bool valid = true; + for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++) + { + valid &= GridwiseTensorRearrangeKernel::CheckValidity( + arg.in_grid_desc_m_k_container_[i], arg.out_grid_desc_m_k_container_[i]); + } + return valid; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_in, // input image + void* p_out, // output image + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + return Argument{static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in, // input image + void* p_out, // output image + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) override + { + return std::make_unique(static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceColumnToImage" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << KPerBlock << ", " + << ScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..62f5454b4eae073216258009786fcdd8a6df0179 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,852 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = as_grid_desc_ak0_m_ak1; + ignore = bs_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceContractionMultipleABD_Xdl_CShuffle + : public DeviceContractionMultipleABD +{ + using DeviceOp = DeviceContractionMultipleABD_Xdl_CShuffle; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ComputeDataType = EDataType; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< + AsDataType, + BsDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + static constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(const std::vector& a_ms_ks_lengths_, + const std::vector& a_ms_ks_strides_) + { + assert(a_ms_ks_lengths_.size() == NumDimM + NumDimK && + a_ms_ks_strides_.size() == NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_, Number{}); + const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + __host__ __device__ static auto + MakeAsGridDescriptor_M_K(const std::array, NumATensor>& as_ms_ks_lengths, + const std::array, NumATensor>& as_ms_ks_strides) + { + return generate_tuple( + [&](auto i) { + return MakeAGridDescriptor_M_K(as_ms_ks_lengths[i], as_ms_ks_strides[i]); + }, + Number{}); + } + + // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_ns_ks_lengths_, + const std::vector& b_ns_ks_strides_) + { + assert(b_ns_ks_lengths_.size() == NumDimN + NumDimK && + b_ns_ks_strides_.size() == NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_, Number{}); + const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + __host__ __device__ static auto + MakeBsGridDescriptor_N_K(const std::array, NumBTensor>& bs_ns_ks_lengths, + const std::array, NumBTensor>& bs_ns_ks_strides) + { + return generate_tuple( + [&](auto i) { + return MakeBGridDescriptor_N_K(bs_ns_ks_lengths[i], bs_ns_ks_strides[i]); + }, + Number{}); + } + + // assume E[M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_ms_ns_lengths_, + const std::vector& e_ms_ns_strides_) + { + assert(e_ms_ns_lengths_.size() == NumDimM + NumDimN && + e_ms_ns_strides_.size() == NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_, Number{}); + const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto + MakeDsGridDescriptor_M_N(const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); + }, + Number{}); + } + + // desc for problem definition + using AsGridDesc_M_K = remove_cvref_t; + using BsGridDesc_N_K = remove_cvref_t; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = remove_cvref_t; + + // desc for blockwise copy + using AsGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BsGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::array p_as_grid, + std::array p_bs_grid, + std::array p_ds_grid, + void* p_e_grid, + const std::array, NumATensor>& a_ms_ks_lengths, + const std::array, NumATensor>& a_ms_ks_strides, + const std::array, NumBTensor>& b_ns_ks_lengths, + const std::array, NumBTensor>& b_ns_ks_strides, + const std::array, NumDTensor>& d_ms_ns_lengths, + const std::array, NumDTensor>& d_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + as_grid_desc_m_k_{}, + bs_grid_desc_n_k_{}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{MakeEGridDescriptor_M_N(e_ms_ns_length, e_ms_ns_stride)}, + as_grid_desc_ak0_m_ak1_{}, + bs_grid_desc_bk0_n_bk1_{}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + // populate pointer, desc for As + static_for<0, NumATensor, 1>{}([&](auto i) { + // using ALayout = remove_cvref_t>; + using ADataType = remove_cvref_t>; + + // A pointer + p_as_grid_(i) = static_cast(p_as_grid[i]); + + // A desc + as_grid_desc_m_k_(i) = + MakeAGridDescriptor_M_K(a_ms_ks_lengths[i], a_ms_ks_strides[i]); + }); + + // populate pointer, desc for Bs + static_for<0, NumBTensor, 1>{}([&](auto i) { + // using BLayout = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + // B pointer + p_bs_grid_(i) = static_cast(p_bs_grid[i]); + + // B desc + bs_grid_desc_n_k_(i) = + MakeBGridDescriptor_N_K(b_ns_ks_lengths[i], b_ns_ks_strides[i]); + }); + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + // using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + MakeEGridDescriptor_M_N(d_ms_ns_lengths[i], d_ms_ns_strides[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_, + bs_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + as_grid_desc_ak0_m_ak1_ = + GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_); + + bs_grid_desc_bk0_n_bk1_ = + GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + + // for sanity check of vector memory access + for(index_t i = 0; i < NumATensor; ++i) + { + as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1; + as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1; + as_max_read_elems_[i] = + CalculateMaxRead(a_ms_ks_lengths[i], a_ms_ks_strides[i]); + } + + for(index_t i = 0; i < NumBTensor; ++i) + { + bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1; + bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1; + bs_max_read_elems_[i] = + CalculateMaxRead(b_ns_ks_lengths[i], b_ns_ks_strides[i]); + } + + for(index_t i = 0; i < NumDTensor; ++i) + { + ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1; + ds_max_read_elems_[i] = + CalculateMaxRead(d_ms_ns_lengths[i], d_ms_ns_strides[i]); + } + + e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1; + e_max_write_elems_ = CalculateMaxRead(e_ms_ns_length, e_ms_ns_stride); + } + + // pointers + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AsGridDesc_M_K as_grid_desc_m_k_; + BsGridDesc_N_K bs_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_; + BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Describe whether the last part of a given dimension of A/B/D/E is consecutive + // in the memory or not. + std::array as_mz_consecutive_; + std::array as_kz_consecutive_; + std::array bs_nz_consecutive_; + std::array bs_kz_consecutive_; + std::array ds_nz_consecutive_; + bool e_nz_consecutive_; + + std::array as_max_read_elems_; + std::array bs_max_read_elems_; + std::array ds_max_read_elems_; + index_t e_max_write_elems_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_abd_xdl_cshuffle< + GridwiseGemm, + typename GridwiseGemm::AsGridPointer, + typename GridwiseGemm::BsGridPointer, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AsGridDesc_AK0_M_AK1, + DeviceOp::BsGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_as_grid_, + arg.p_bs_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.as_grid_desc_ak0_m_ak1_, + arg.bs_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // check vector load/store + { + bool valid_as_access = true; + static_for<0, NumATensor, 1>{}([&](auto i) { + const bool valid_a_vector_size = + arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0; + const bool valid_a_access_dim_m = + ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i]; + const bool valid_a_access_dim_k = + ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i]; + const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k; + if(!(valid_a_vector_size && valid_a_access_dim)) + { + valid_as_access = false; + } + }); + if(!valid_as_access) + { + return false; + } + + bool valid_bs_access = true; + static_for<0, NumBTensor, 1>{}([&](auto i) { + const bool valid_b_vector_size = + arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0; + const bool valid_b_access_dim_n = + BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i]; + const bool valid_b_access_dim_k = + BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i]; + const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k; + if(!(valid_b_vector_size && valid_b_access_dim)) + { + valid_bs_access = false; + } + }); + if(!valid_bs_access) + { + return false; + } + + bool valid_ds_access = true; + static_for<0, NumDTensor, 1>{}([&](auto i) { + const bool valid_d_vector_size = + arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector read of Ds is always on N dimension. + const bool valid_d_access_dim = arg.ds_nz_consecutive_[i]; + if(!(valid_d_vector_size && valid_d_access_dim)) + { + valid_ds_access = false; + } + }); + if(!valid_ds_access) + { + return false; + } + + const bool valid_e_vector_size = + arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector write of E is always on N dimension. + const bool valid_e_access_dim = arg.e_nz_consecutive_; + if(!(valid_e_vector_size && valid_e_access_dim)) + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + const std::array, NumATensor>& a_ms_ks_lengths, + const std::array, NumATensor>& a_ms_ks_strides, + const std::array, NumBTensor>& b_ns_ks_lengths, + const std::array, NumBTensor>& b_ns_ks_strides, + const std::array, NumDTensor>& d_ms_ns_lengths, + const std::array, NumDTensor>& d_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + d_ms_ns_lengths, + d_ms_ns_strides, + e_ms_ns_length, + e_ms_ns_stride, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + const std::array, NumATensor>& as_ms_ks_lengths, + const std::array, NumATensor>& as_ms_ks_strides, + const std::array, NumBTensor>& bs_ns_ks_lengths, + const std::array, NumBTensor>& bs_ns_ks_strides, + const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + as_ms_ks_lengths, + as_ms_ks_strides, + bs_ns_ks_lengths, + bs_ns_ks_strides, + ds_ms_ns_lengths, + ds_ms_ns_strides, + e_ms_ns_length, + e_ms_ns_stride, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceContractionMultipleABD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index deb938156d23d89297cb8e132dfcad3d5ad45c1c..1f65afed3d72d0b5069409f7ccd07cd09edf1a60 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -13,6 +13,7 @@ #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -53,7 +54,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -145,7 +146,8 @@ template + typename ComputeDataType = ADataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> struct DeviceContractionMultipleD_Xdl_CShuffle : public DeviceContractionMultipleD + CDEElementwiseOperation, + ComputeDataType> { using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; @@ -181,7 +184,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle return generate_tuple([&](auto i) { return vec[i]; }, num); }; - const auto a_ms_ns_lengths = to_tuple(a_ms_ks_lengths_vec, Number{}); + const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number{}); const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number{}); // dimension Ids for M0, M1, ... @@ -192,14 +195,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle typename arithmetic_sequence_gen::type{}; // lengths for M0, M1, ... - const auto mLengths = get_container_subset(a_ms_ns_lengths, mDimIds); + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); // lengths for K0, K1, ... - const auto kLengths = get_container_subset(a_ms_ns_lengths, kDimIds); + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] const auto a_grid_desc_ms_ks = - make_naive_tensor_descriptor(a_ms_ns_lengths, a_ms_ks_strides); + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( @@ -313,6 +316,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, @@ -379,7 +384,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const void* p_b_grid, std::array p_ds_grid, void* p_e_grid, - const std::vector& a_ms_ns_lengths, + const std::vector& a_ms_ks_lengths, const std::vector& a_ms_ks_strides, const std::vector& b_ns_ks_lengths, const std::vector& b_ns_ks_strides, @@ -394,7 +399,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle p_b_grid_{static_cast(p_b_grid)}, p_ds_grid_{}, p_e_grid_{static_cast(p_e_grid)}, - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ns_lengths, a_ms_ks_strides)}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ks_lengths, a_ms_ks_strides)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)}, ds_grid_desc_m_n_{}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)}, @@ -407,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, - cde_element_op_{cde_element_op}, - a_mz_stride_{}, - a_kz_stride_{}, - b_nz_stride_{}, - b_kz_stride_{}, - ds_nz_stride_{}, - e_nz_stride_{} + cde_element_op_{cde_element_op} { // populate pointer, batch stride, desc for Ds static_for<0, NumDTensor, 1>{}([&](auto i) { @@ -444,18 +443,26 @@ struct DeviceContractionMultipleD_Xdl_CShuffle } // for sanity check of vector memory access - a_mz_stride_ = a_ms_ks_strides[NumDimM - 1]; - a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1]; + a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1; + a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1; + a_max_read_elems_ = + CalculateMaxRead(a_ms_ks_lengths, a_ms_ks_strides); - b_nz_stride_ = b_ns_ks_strides[NumDimN - 1]; - b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1]; + b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1; + b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1; + b_max_read_elems_ = + CalculateMaxRead(b_ns_ks_lengths, b_ns_ks_strides); for(index_t i = 0; i < NumDTensor; ++i) { - ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1]; + ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1; + ds_max_read_elems_[i] = + CalculateMaxRead(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); } - e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1]; + e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1; + e_max_write_elems_ = + CalculateMaxRead(e_ms_ns_lengths, e_ms_ns_strides); } void Print() const @@ -495,15 +502,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle BElementwiseOperation b_element_op_; CDEElementwiseOperation cde_element_op_; - // Strides for the last M/N/K dimensions of A/B/Ds/E - // for sanity check of vector load/store - index_t a_mz_stride_; - index_t a_kz_stride_; - index_t b_nz_stride_; - index_t b_kz_stride_; - std::array ds_nz_stride_; - index_t e_mz_stride_; - index_t e_nz_stride_; + // Describe whether the last part of a given dimension of A/B/D/E is consecutive + // in the memory or not. + bool a_mz_consecutive_; + bool a_kz_consecutive_; + bool b_nz_consecutive_; + bool b_kz_consecutive_; + std::array ds_nz_consecutive_; + bool e_nz_consecutive_; + + index_t a_max_read_elems_; + index_t b_max_read_elems_; + std::array ds_max_read_elems_; + index_t e_max_write_elems_; }; // Invoker @@ -591,7 +602,9 @@ struct DeviceContractionMultipleD_Xdl_CShuffle return false; } - if(ck::get_device_name() != "gfx90a" && std::is_same::value) + if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" && + ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" && + std::is_same::value) { return false; } @@ -610,65 +623,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), "wrong!"); - // vector memory access of A: could be on M or AK1 dimension - if constexpr(ABlockTransferSrcVectorDim == 1) + const bool valid_a_vector_size = + arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0; + const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_; + const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_; + const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k; + if(!(valid_a_vector_size && valid_a_access_dim)) { - if(!(arg.a_mz_stride_ == 1 && - arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) - { - return false; - } - } - else - { - if(!(arg.a_kz_stride_ == 1 && - arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) - { - return false; - } + return false; } - // vector memory access of B: could be on N or BK1 dimension - if constexpr(BBlockTransferSrcVectorDim == 1) - { - if(!(arg.b_nz_stride_ == 1 && - arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) - { - return false; - } - } - else + const bool valid_b_vector_size = + arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0; + const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_; + const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_; + const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k; + if(!(valid_b_vector_size && valid_b_access_dim)) { - if(!(arg.b_kz_stride_ == 1 && - arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) - { - return false; - } + return false; } - // vector memory access of Ds: always on NPerBlock dimension - bool valid_d_access = true; - + bool valid_ds_access = true; static_for<0, NumDTensor, 1>{}([&](auto i) { - if(!(arg.ds_nz_stride_[i] == 1 && - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % - CDEBlockTransferScalarPerVector_NPerBlock == - 0)) + const bool valid_d_vector_size = + arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector read of Ds is always on N dimension. + const bool valid_d_access_dim = arg.ds_nz_consecutive_[i]; + if(!(valid_d_vector_size && valid_d_access_dim)) { - valid_d_access = false; + valid_ds_access = false; } }); - - if(valid_d_access == false) + if(!valid_ds_access) { return false; } - // vector memory access of E: always on NPerBlock dimension - if(!(arg.e_nz_stride_ == 1 && - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % - CDEBlockTransferScalarPerVector_NPerBlock == - 0)) + const bool valid_e_vector_size = + arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector write of E is always on N dimension. + const bool valid_e_access_dim = arg.e_nz_consecutive_; + if(!(valid_e_vector_size && valid_e_access_dim)) { return false; } @@ -686,7 +681,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const void* p_b, std::array p_ds, void* p_e, - const std::vector& a_ms_ns_lengths, + const std::vector& a_ms_ks_lengths, const std::vector& a_ms_ks_strides, const std::vector& b_ns_ks_lengths, const std::vector& b_ns_ks_strides, @@ -702,7 +697,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle p_b, p_ds, p_e, - a_ms_ns_lengths, + a_ms_ks_lengths, a_ms_ks_strides, b_ns_ks_lengths, b_ns_ks_strides, @@ -723,7 +718,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle const void* p_b, std::array p_ds, void* p_e, - const std::vector& a_ms_ns_lengths, + const std::vector& a_ms_ks_lengths, const std::vector& a_ms_ks_strides, const std::vector& b_ns_ks_lengths, const std::vector& b_ns_ks_strides, @@ -739,7 +734,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle p_b, p_ds, p_e, - a_ms_ns_lengths, + a_ms_ks_lengths, a_ms_ks_strides, b_ns_ks_lengths, b_ns_ks_strides, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..838305f187adcad799974cad5085ba0596e8dfbc --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * Calculates the maximum number of subsequent elements of the fast changing dimension + * that are consecutive in memory. + * + * Example: + * NumDimM = 2, NumDimK = 3 + * A shape = [ 2, 3, 4, 5, 6] + * A strides = [360, 120, 30, 6, 1] + * | M | | K | + * It follows from strides that K is FCD and all the subsequent elements of K are consecutive + * in memory. + * But if strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of K would be + * consecutive in memory. + * + * Assumes that the dimensions are split into two groups of `NumDim1` and `NumDim2` dimensions. + */ +template +auto CalculateMaxRead(const std::vector& lengths, const std::vector& strides) +{ + if(lengths.size() != NumDim1 + NumDim2) + { + std::ostringstream err; + err << "Incorrect number of lengths in " + << "device_contraction_utils.hpp" + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + if(strides.size() != NumDim1 + NumDim2) + { + std::ostringstream err; + err << "Incorrect number of strides in " + << "device_contraction_utils.hpp" + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + // Determine the beginning and end idx of the group representing the FCD. + index_t begin_idx, end_idx; + if(strides[NumDim1 - 1] == 1) + { + begin_idx = 0; + end_idx = NumDim1 - 1; + } + else if(strides[NumDim1 + NumDim2 - 1] == 1) + { + begin_idx = NumDim1; + end_idx = NumDim1 + NumDim2 - 1; + } + else + { + // The dimension consecutive in memory is not the last dimension of any group, so only + // one element can be read/written at once. + return 1; + } + + index_t consecutive_stride = 1; + for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx) + { + if(strides[dim_idx] == consecutive_stride) + { + consecutive_stride *= lengths[dim_idx]; + } + else + { + break; + } + } + const index_t max_subsequent_elems = consecutive_stride; + return max_subsequent_elems; +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index a8e586b20c38a40438636c5f9cafafb5c34409a9..55cf8df27231be7e46ed28172e20870a246c7035 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -56,7 +56,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index 3178f73f4b530e67b399cd0ae71064a4b66fa837..d95671be7e71bdd83fbe0e27efbbdf1a16168a85 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -1393,9 +1393,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl static bool IsSupportedArgument(const Argument& arg) { // check device - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102")) + if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() || + ck::is_navi3_supported())) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..67b6f87465ab9cfcb2672fe4ce9dc32504e4f903 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp @@ -0,0 +1,371 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct DeviceElementwise3dImpl : public DeviceElementwise +{ + static constexpr index_t NumDim = NumDim_m + NumDim_n + NumDim_k; + + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static auto GenerateInDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + } + + static auto GenerateOutDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + } + + using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple()); + using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); + + template + static auto PadDescriptor_MNK(Desc_MNK desc_mnk, + index_t gridSize, + index_t blockSize, + index_t num_threads_m, + index_t num_threads_n, + index_t num_threads_k) + { + std::ignore = blockSize; + std::ignore = gridSize; + + const auto m = desc_mnk.GetLength(I0); + const auto n = desc_mnk.GetLength(I1); + const auto k = desc_mnk.GetLength(I2); + + const index_t loop_step_m = num_threads_m * MPerThread; + const index_t loop_step_n = num_threads_n * NPerThread; + const index_t loop_step_k = num_threads_k * KPerThread; + + const auto pad_m = math::integer_least_multiple(m, loop_step_m) - m; + const auto pad_n = math::integer_least_multiple(n, loop_step_n) - n; + const auto pad_k = math::integer_least_multiple(k, loop_step_k) - k; + + const auto desc_mnk_pad = + transform_tensor_descriptor(desc_mnk, + make_tuple(make_right_pad_transform(m, pad_m), + make_right_pad_transform(n, pad_n), + make_right_pad_transform(k, pad_k)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + return desc_mnk_pad; + } + + static auto MakeDescriptor_MNK(const std::array& lengths, + const std::array& stride, + index_t gridSize, + index_t blockSize, + index_t num_threads_m, + index_t num_threads_n, + index_t num_threads_k) + { + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDim_m, 1>::type(); + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type(); + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type(); + + const auto mLengths = get_container_subset(tupleOfShape, mDimIds); + const auto nLengths = get_container_subset(tupleOfShape, nDimIds); + const auto kLengths = get_container_subset(tupleOfShape, kDimIds); + + // merge nd to 3d desc - [s0 * s1 * ...] + if constexpr(NumDim > 3) + { + const auto desc_mnk = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(mLengths), + make_merge_transform(nLengths), + make_merge_transform(kLengths)), + make_tuple(mDimIds, nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return PadDescriptor_MNK( + desc_mnk, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k); + } + else + return PadDescriptor_MNK( + desc, gridSize, blockSize, num_threads_m, num_threads_n, num_threads_k); + } + + template + static auto GenerateInOutGrid3dDescTuple(Number) + { + return generate_tuple( + [&](auto) { + if constexpr(NumDim > 3) + { + return MakeDescriptor_MNK({1, 1, 1}, {1, 1, 1}, 1, 1, 1, 1, 1); + } + else + { + return MakeDescriptor_MNK({1}, {1}, 1, 1, 1, 1, 1); + }; + }, + Number{}); + } + + using OutGrid3dDescTuple = decltype(GenerateInOutGrid3dDescTuple(Number{})); + using InGrid3dDescTuple = decltype(GenerateInOutGrid3dDescTuple(Number{})); + + using GridwiseElementwise = GridwiseElementwise_3D; + + struct Argument : public BaseArgument + { + Argument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) + + : lengths_(lengths), + inStridesArray_(inStridesArray), + outStridesArray_(outStridesArray), + elementwise_op_(elementwise_op), + blockSize_(256) + { + static_assert(NumDim_m > 0, ""); + static_assert(NumDim_n > 0, ""); + static_assert(NumDim_k > 0, ""); + + in_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(in_dev_buffers[I.value]); + }, + Number{}); + + out_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(out_dev_buffers[I.value]); + }, + Number{}); + } + + InDataTypePointerTuple in_dev_buffers_; + OutDataTypePointerTuple out_dev_buffers_; + + std::array lengths_; + std::array, NumInput> inStridesArray_; + std::array, NumOutput> outStridesArray_; + + ElementwiseOperation elementwise_op_; + index_t blockSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t gridSize = getAvailableComputeUnitCount(stream_config) * arg.blockSize_; + index_t num_threads_m = gridSize / (16 * 16); + index_t num_threads_n = 16; + index_t num_threads_k = 16; + + auto in_grid_3d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_MNK(arg.lengths_, + arg.inStridesArray_[I.value], + gridSize, + arg.blockSize_, + num_threads_m, + num_threads_n, + num_threads_k); + }, + Number{}); + + auto out_grid_3d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_MNK(arg.lengths_, + arg.outStridesArray_[I.value], + gridSize, + arg.blockSize_, + num_threads_m, + num_threads_n, + num_threads_k); + }, + Number{}); + + const auto kernel = kernel_elementwise_3d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + in_grid_3d_desc_tuple, + out_grid_3d_desc_tuple, + arg.in_dev_buffers_, + arg.out_dev_buffers_, + arg.elementwise_op_, + num_threads_m, + num_threads_n, + num_threads_k); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + if((ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942")) + { + return false; + } + + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg == nullptr) + return false; + + if(pArg->lengths_.back() % MPerThread != 0) + return false; + + auto IsScalarPerVectorValid = [&](const std::array& lengths, + const std::array& strides, + index_t scalarPerVector, + index_t vectorDim) { + if(strides[vectorDim] == 1 && + (lengths[vectorDim] % scalarPerVector == 0 || + lengths[vectorDim] % scalarPerVector == lengths[vectorDim])) + { + return true; + } + + if(strides[vectorDim] >= scalarPerVector) + { + return true; + } + return false; + }; + + bool valid = true; + static_for<0, NumInput, 1>{}([&](auto I) { + valid = valid && IsScalarPerVectorValid(pArg->lengths_, + pArg->inStridesArray_[I.value], + InScalarPerVectorSeq::At(I), + NumDim_m - 1); + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + valid = valid && IsScalarPerVectorValid(pArg->lengths_, + pArg->outStridesArray_[I.value], + OutScalarPerVectorSeq::At(I), + NumDim - 1); + }); + + return valid; + } + + std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) override + { + return std::make_unique(lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op); + } + + static auto MakeInvoker() { return Invoker{}; } + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + } +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp index e203468c6174f0ae2c8bde33a716d0665f2e8b6c..37867f1eaae934dc568b0c1503adb356bd468e9a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp @@ -296,6 +296,28 @@ struct DeviceElementwiseImpl { return std::make_unique(); }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceElementwiseImpl<" ; + str << "NumDim_" << NumDim << ","; + str << "MPerThread_" << MPerThread << ","; + + str << "InScalarPerVector"; + static_for<0, InScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << InScalarPerVectorSeq::At(i).value; }); + str << ","; + str << "OutScalarPerVector"; + static_for<0, OutScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << OutScalarPerVectorSeq::At(i).value; }); + + str << ">"; + // clang-format on + + return str.str(); + } + }; // namespace device } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5e0f5e288ee935161ce2a91a438de643a62d17dd --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp @@ -0,0 +1,329 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceElementwiseImpl : public DeviceElementwise +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static auto GenerateInDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + static auto GenerateOutDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple()); + using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); + + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + constexpr auto I0 = Number<0>{}; + + const auto m = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * MPerThread; + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(const std::array& lengths, + const std::array& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] + if constexpr(NumDim > 1) + { + const auto desc_m = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + else + return PadDescriptor_M_1d(desc, gridSize, blockSize); + } + + template + static auto GenerateInOutGrid1dDescTuple(Number) + { + return generate_tuple( + [&](auto) { + if constexpr(NumDim > 1) + { + return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1); + } + else + { + return MakeDescriptor_M({1}, {1}, 1, 1); + }; + }, + Number{}); + }; + + using InGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number{})); + using OutGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number{})); + + using GridwiseElementwise = GridwiseElementwise_1D; + + struct Argument : public BaseArgument + { + Argument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op, + UnaryOperation unary_op, + Scale scale_op) + + : lengths_(lengths), + inStridesArray_(inStridesArray), + outStridesArray_(outStridesArray), + elementwise_op_(elementwise_op), + unary_op_(unary_op), + scale_op_(scale_op), + blockSize_(256) + { + in_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(in_dev_buffers[I.value]); + }, + Number{}); + + out_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(out_dev_buffers[I.value]); + }, + Number{}); + } + + InDataTypePointerTuple in_dev_buffers_; + OutDataTypePointerTuple out_dev_buffers_; + + std::array lengths_; + std::array, NumInput> inStridesArray_; + std::array, NumOutput> outStridesArray_; + + ElementwiseOperation elementwise_op_; + UnaryOperation unary_op_; + Scale scale_op_; + index_t blockSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + + auto in_grid_1d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_M( + arg.lengths_, arg.inStridesArray_[I.value], gridSize, arg.blockSize_); + }, + Number{}); + + auto out_grid_1d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_M( + arg.lengths_, arg.outStridesArray_[I.value], gridSize, arg.blockSize_); + }, + Number{}); + + const auto kernel = kernel_elementwise_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + in_grid_1d_desc_tuple, + out_grid_1d_desc_tuple, + arg.in_dev_buffers_, + arg.out_dev_buffers_, + arg.elementwise_op_, + arg.unary_op_, + arg.scale_op_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(arg.lengths_.back() % MPerThread != 0) + return false; + + auto IsScalarPerVectorValid = [&](const std::array& lengths, + const std::array& strides, + index_t scalarPerVector) { + if(strides.back() == 1 && lengths.back() % scalarPerVector == 0) + return true; + + if(strides.back() != 1 && scalarPerVector == 1) + return true; + + return false; + }; + + bool valid = true; + static_for<0, NumInput, 1>{}([&](auto I) { + if(!IsScalarPerVectorValid( + arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I))) + valid = false; + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + if(!IsScalarPerVectorValid( + arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I))) + valid = false; + }); + + return valid; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op, + UnaryOperation unary_op, + Scale scale_op) + { + return Argument{lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op, + unary_op, + scale_op}; + } + + std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op, + UnaryOperation unary_op, + Scale scale_op) override + { + return std::make_unique(lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op, + unary_op, + scale_op); + } + + static auto MakeInvoker() { return Invoker{}; } + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index eedf384cd9c99fce8c32ecc54be476a036165f86..bac124a2f1f985d16f2f2a1fed28f8bdadf7b3ce 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -11,7 +11,6 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp" -#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" #include "ck/host_utility/device_prop.hpp" @@ -60,7 +59,6 @@ template < typename CThreadTransferSrcDstAccessOrder, index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferDstScalarPerVector, - GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default, enable_if_t< is_same_v && is_same_v && @@ -238,8 +236,7 @@ struct DeviceGemmDl : public DeviceGemm; + CThreadTransferDstScalarPerVector>; using AGridDesc_K0_M0_M1_K1 = decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); @@ -276,6 +273,9 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, true, - true, - GemmDlAlg>; + true>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -402,8 +405,7 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, true, - false, - GemmDlAlg>; + false>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -429,8 +431,7 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, false, - true, - GemmDlAlg>; + true>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -456,8 +457,7 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, false, - false, - GemmDlAlg>; + false>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -492,19 +492,52 @@ struct DeviceGemmDl : public DeviceGemm::value) + { + constexpr auto A_K_vec_length = + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I0) * + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I3); + if(arg.K_raw_ % A_K_vec_length != 0) + { + return false; + } + } + else + { + constexpr auto A_M_vec_lenght = + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I1) * + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I2); + if(arg.M_raw_ % A_M_vec_lenght != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + constexpr auto B_N_vec_lenght = + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I1) * + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I2); + if(arg.N_raw_ % B_N_vec_lenght != 0) + { + return false; + } + } + else { - if(ck::get_device_name() == "gfx1030") + constexpr auto B_K_vec_length = + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I0) * + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I3); + if(arg.K_raw_ % B_K_vec_length != 0) { - return GridwiseGemm::CheckValidity( - arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + return false; } - return false; } - if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102") + if(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() || + ck::is_navi3_supported()) { return GridwiseGemm::CheckValidity( arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp deleted file mode 100644 index 336ce31e8214391183a94f5af6951852bc5ed067..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp +++ /dev/null @@ -1,133 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" -#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template < - typename ADataType, - typename BDataType, - typename CDataType, - typename AccDataType, - typename ALayout, - typename BLayout, - typename CLayout, - typename AElementwiseOperation, - typename BElementwiseOperation, - typename CElementwiseOperation, - GemmSpecialization GemmSpec, - index_t BlockSize, - index_t MPerBlock, - index_t NPerBlock, - index_t K0PerBlock, - index_t K1, - index_t M1PerThread, - index_t N1PerThread, - index_t KPerThread, - typename M1N1ThreadClusterM1Xs, - typename M1N1ThreadClusterN1Xs, - typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, - typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, - typename ABlockTransferThreadClusterArrangeOrder, - typename ABlockTransferSrcAccessOrder, - typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, - typename ABlockTransferSrcVectorTensorContiguousDimOrder, - typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, - typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1, - typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1, - typename BBlockTransferThreadClusterArrangeOrder, - typename BBlockTransferSrcAccessOrder, - typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, - typename BBlockTransferSrcVectorTensorContiguousDimOrder, - typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, - typename CThreadTransferSrcDstAccessOrder, - index_t CThreadTransferSrcDstVectorDim, - index_t CThreadTransferDstScalarPerVector, - enable_if_t< - is_same_v && - is_same_v && - is_same_v, - bool> = false> -struct DeviceGemmDlDpp8 : public DeviceGemmDl - -{ - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGemmDlDpp8" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << K0PerBlock << ", " - << K1 << ", " - << M1PerThread << ", " - << N1PerThread << ", " - << KPerThread - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..77f28eca9d080af0ab961e3a2ba8dc188e7f31a3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmDpp : public DeviceGemm +{ + using GridwiseGemm = GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp< + BlockSize, + ADataType, + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + MPerBlock, + NPerBlock, + KPerBlock, + MPerDpp, + NPerDpp, + AK1, + BK1, + MDppPerWave, + NDppPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 1, 3, 5>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch, + PipelineVer>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_dpp has invalid setting"); + } + + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) + { + const auto kernel = kernel_gemm_dpp; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + else + { + const auto kernel = kernel_gemm_dpp; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(ck::is_navi2_supported() || ck::is_navi3_supported()) + { + return GridwiseGemm::CheckValidity(karg); + } + return false; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmDpp" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerDpp << ", " + << NPerDpp << ", " + << MDppPerWave << ", " + << MDppPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 + << ">" + << " NumPrefetch: " + << NumPrefetch << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..21914d466d7d759ceaad96bf505b467bfdfbc6d9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,768 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = as_grid_desc_ak0_m_ak1; + ignore = bs_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD +{ + using DeviceOp = DeviceGemmMultipleABD_Xdl_CShuffle; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + +#if 0 + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideAs) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideAs, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideAs)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideBs) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideBs)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideBs, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } +#endif + using ComputeDataType = EDataType; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< + AsDataType, + BsDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + // desc for problem definition + using AsGridDesc_M_K = + remove_cvref_t( + {}, {}, {}))>; + using BsGridDesc_N_K = + remove_cvref_t( + {}, {}, {}))>; + using DsGridDesc_M_N = + remove_cvref_t( + {}, {}, {}))>; + using EGridDesc_M_N = + decltype(GridwiseGemm::template MakeEGridDescriptor_M_N(1, 1, 1)); + + // desc for blockwise copy + using AsGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BsGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::array p_as_grid, + std::array p_bs_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + as_grid_desc_m_k_{}, + bs_grid_desc_n_k_{}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{GridwiseGemm::template MakeEGridDescriptor_M_N( + MRaw, NRaw, StrideE)}, + as_grid_desc_ak0_m_ak1_{}, + bs_grid_desc_bk0_n_bk1_{}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw} + { + // populate pointer, desc for As + static_for<0, NumATensor, 1>{}([&](auto i) { + using ALayout = remove_cvref_t>; + using ADataType = remove_cvref_t>; + + // A pointer + p_as_grid_(i) = static_cast(p_as_grid[i]); + + // A desc + as_grid_desc_m_k_(i) = + GridwiseGemm::template MakeAGridDescriptor_M_K( + MRaw, KRaw, StrideAs[i]); + }); + + // populate pointer, desc for Bs + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BLayout = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + // B pointer + p_bs_grid_(i) = static_cast(p_bs_grid[i]); + + // B desc + bs_grid_desc_n_k_(i) = + GridwiseGemm::template MakeBGridDescriptor_N_K( + KRaw, NRaw, StrideBs[i]); + }); + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + GridwiseGemm::template MakeEGridDescriptor_M_N( + MRaw, NRaw, StrideDs[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_, + bs_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + as_grid_desc_ak0_m_ak1_ = + GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_); + + bs_grid_desc_bk0_n_bk1_ = + GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + } + + void Print() const + { + // std::cout << "A[M, K]: " << as_grid_desc_m_k_ << std::endl; + // std::cout << "B[N, K]: " << bs_grid_desc_n_k_ << std::endl; + // static_for<0, NumDTensor, 1>{}( + //[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + // std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AsGridDesc_M_K as_grid_desc_m_k_; + BsGridDesc_N_K bs_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_; + BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_abd_xdl_cshuffle< + GridwiseGemm, + typename GridwiseGemm::AsGridPointer, + typename GridwiseGemm::BsGridPointer, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AsGridDesc_AK0_M_AK1, + DeviceOp::BsGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_as_grid_, + arg.p_bs_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.as_grid_desc_ak0_m_ak1_, + arg.bs_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // check vector load/store + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + bool all_valid = true; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ALayout = remove_cvref_t>; + // check vector load of A + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + all_valid = false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + all_valid = false; + } + } + else + { + all_valid = false; + } + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BLayout = remove_cvref_t>; + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + all_valid = false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + all_valid = false; + } + } + else + { + all_valid = false; + } + }); + + // check vector load of Ds + // only support RowMajor for now + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // check vector store of E + // only support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideAs, + StrideBs, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideAs, + StrideBs, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmMultipleABD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp index ad51096db7f6d8aa1cd8b4a8c7590ff385009943..8490fa52fd6e394ebd74c44166caa267751c42c4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp @@ -50,9 +50,8 @@ __global__ void const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \ - defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); @@ -552,11 +551,8 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD( @@ -821,7 +821,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle return (workspace_size); }; - void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override + void SetWorkSpacePointer(BaseArgument* pArg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override { Argument* pArg_ = dynamic_cast(pArg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 916f29a904eba4172b13651249263d085b16f378..bb2db930c8ea7b1ed34b17a7cde56a2c7c6daafc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -61,7 +61,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp index 44b3518e2c950e15ec3d5cfd240c443029987622..fd90c7f1ea8720e4c0f352f346f7f51c0ad689b3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp @@ -484,8 +484,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 6502155f87aeba859f79a2321dfad7a38b4dfbf2..f5735b39cc532605cdf14dd8854e016e626d8d36 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -22,7 +22,8 @@ namespace ck { template (p_a_grid, @@ -145,7 +146,8 @@ template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeDataType = EDataType> struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad + : public DeviceGemmMultipleD +{ + static constexpr auto I1 = Number<1>{}; + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + GemmSpec, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferScalarPerVector, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferScalarPerVector, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + using Argument = typename GridwiseGemm::Argument; + + struct Invoker : public BaseInvoker + { + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + typename GridwiseGemm::AGridDesc_AK0_M_AK1, + typename GridwiseGemm::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + const auto K = arg.a_grid_desc_m_k_.GetLength(I1); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!ck::is_lds_direct_load_supported()) + { + return false; + } + + // Check vector load/store. + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // Check vector load of A. + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + if(arg.MRaw_ % ABlockTransferScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // Check vector load of B. + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + if(arg.NRaw_ % BBlockTransferScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // Check vector load of Ds. + // For now, only the RowMajor layout is supported. + bool all_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // Check vector load of E. + // For now, only the RowMajor layout is supported. + if constexpr(is_same_v) + { + if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}}; + + // clang-format off + str << "DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferScalarPerVector << ", " + << BBlockTransferScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index f64450b75f272dbb0f9ef4066e055a2f59c3879e..98d14caa6d235c2aa4cb0589d0f8951b31625abc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -411,8 +411,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp index ebbb1d38c000ba44ce53f9a8110114c41a299b14..5188ece3335d9be07e4b64c74811e23fee86c382 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp @@ -184,7 +184,7 @@ struct DeviceGemmXdl : public DeviceGemm || is_same_v || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp index 86bc7b5bbb0640c4c2bf0e947b3adae56b358f13..dd41f4bca0a3ec7ceb27d43db553b0a6a5b265bb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp @@ -66,7 +66,8 @@ template + typename ComputeTypeA = CDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGemm_Xdl_CShuffle : public DeviceGemm; + ComputeTypeA, + ComputeTypeB>; using Argument = typename GridwiseGemm::Argument; @@ -276,6 +278,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm +{ + static constexpr auto I1 = Number<1>{}; + + using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< + ALayout, + BLayout, + ck::Tuple<>, + ELayout, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + GemmSpec, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferScalarPerVector, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferScalarPerVector, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + using Argument = typename GridwiseGemm::Argument; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + typename GridwiseGemm::AGridDesc_AK0_M_AK1, + typename GridwiseGemm::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + // arg.p_ds_grid_, + ck::Tuple<>{}, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + const auto K = arg.a_grid_desc_m_k_.GetLength(I1); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!ck::is_lds_direct_load_supported()) + { + return false; + } + + // Check vector load/store. + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // Check vector load of A. + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + if(arg.MRaw_ % ABlockTransferScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // Check vector load of B. + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + if(arg.NRaw_ % BBlockTransferScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // Check vector load of E. + // For now, only the RowMajor layout is supported. + if constexpr(is_same_v) + { + if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + using EmptyDsPointers = std::array; + using EmptyDsStrides = std::array; + + return Argument{p_a, + p_b, + EmptyDsPointers{}, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + EmptyDsStrides{}, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + using EmptyDsPointers = std::array; + using EmptyDsStrides = std::array; + + return std::make_unique(p_a, + p_b, + EmptyDsPointers{}, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + EmptyDsStrides{}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}}; + + // clang-format off + str << "DeviceGemm_Xdl_CShuffle_LdsDirectLoad" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferScalarPerVector << ", " + << BBlockTransferScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer] << ", " + << "Prefetch: " + << NumGemmKPrefetchStage; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d49c63f1477b2711bba6f2e8ed1eca2296e17d85 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v2.hpp @@ -0,0 +1,306 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm +{ + using DeviceOp = DeviceGemm_Xdl_CShuffleV2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v2< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + float ave_time = 0; + const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1; + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K) == 3) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v2; + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v2; + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemm_Xdl_CShuffleV2" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp index e7365daed2d72b961f508ef343dc0a0b9e41113b..7f28ec768083dad963592bbfeda6d91b0ab95580 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp @@ -59,7 +59,10 @@ template + PipelineVersion PipelineVer = PipelineVersion::v1, + LoopScheduler LoopSched = make_default_loop_scheduler(), + typename LDSTypeA = ComputeType, + typename LDSTypeB = ComputeType> struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK + CElementwiseOperation, + ComputeType> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -78,7 +82,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + struct Argument : public GridwiseGemm::Argument + { + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0Padded_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + MPadded_, + NPadded_, + KPadded_, + K0Padded_, + k_batch_), + a_element_op(a_element_op_), + b_element_op(b_element_op_), + c_element_op(c_element_op_) + { + } + + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CElementwiseOperation c_element_op; + }; - using Argument = typename GridwiseGemm::Argument; using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; // Invoker @@ -154,9 +206,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK(karg), + b2c_map, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); }; if(has_main_k0_block_loop) @@ -179,7 +240,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; Run(kernel); } @@ -189,7 +253,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; Run(kernel); } @@ -202,7 +269,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; Run(kernel); } @@ -212,7 +282,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; Run(kernel); } @@ -260,12 +333,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK(static_cast(p_a), @@ -310,8 +386,11 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + str << GridwiseGemm::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched] + << ", PipelineVersion: " << PipelineVersionToString[PipelineVer]; + + return str.str(); + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dd33e577bfdd081312b67c1db5a9d9b176c7a0bf --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp @@ -0,0 +1,423 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template + +struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using GridwiseGemm = GridwiseGemm_xdlops_splitk_lds_direct_load< + BlockSize, + ADataType, + BDataType, + AccDataType, + CDataType, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + NumGemmKPrefetchStage, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferSrcVectorDim, + ABlockTransferScalarPerVector, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferSrcVectorDim, + BBlockTransferScalarPerVector, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer, + ComputeType>; + + struct Argument : public GridwiseGemm::Argument + { + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0Padded_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + MPadded_, + NPadded_, + KPadded_, + K0Padded_, + k_batch_), + a_element_op(a_element_op_), + b_element_op(b_element_op_), + c_element_op(c_element_op_) + { + } + + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CElementwiseOperation c_element_op; + }; + + using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Invoker + struct Invoker : public BaseInvoker + { + + void Print(const Argument& karg) { karg.Print(); } + + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(karg); + } + + const auto kbatch = karg.k_batch; + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid " + "setting"); + } + + const auto b2c_map = DefaultBlock2CTileMap{}; + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch); + const auto K0Padded = karg.K0Padded; + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(kbatch > 1) + hipGetErrorString(hipMemsetAsync(karg.p_c_grid, + 0, + karg.M * karg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + static_cast(karg), + b2c_map, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = + kernel_gemm_xdlops_splitk_lds_direct_load; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_splitk_lds_direct_load< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = + kernel_gemm_xdlops_splitk_lds_direct_load; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_splitk_lds_direct_load< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(karg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) + { + return Argument(p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + GridwiseGemm::CalculateMPadded(M), + GridwiseGemm::CalculateNPadded(N), + GridwiseGemm::CalculateKPadded(K, KBatch), + GridwiseGemm::CalculateK0Padded(K, KBatch), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + GridwiseGemm::CalculateMPadded(M), + GridwiseGemm::CalculateNPadded(N), + GridwiseGemm::CalculateKPadded(K, KBatch), + GridwiseGemm::CalculateK0Padded(K, KBatch), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}}; + + // clang-format off + str << "DeviceGemmXdlSplitKCShuffle_LdsDirectLoad" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferScalarPerVector << ", " + << BBlockTransferScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer] << ", " + << "Prefetch: " + << NumGemmKPrefetchStage; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp index 8de42ba9efcdc018d0d945ace9aa3e2da9a41cbf..51b8958d61fe99a67ce1a9815c8b1495c6a8967a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp @@ -226,7 +226,9 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK(pArg); @@ -241,9 +243,7 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index 0b1bf0f1792cec386ffe3a2cdf1444d18f14e4cc..cc022b89c5ed472758300a1cd4e19cc5b2924f09 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -38,7 +38,7 @@ __global__ void const CDEElementwiseOperation cde_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -355,9 +355,13 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0b3de153c330faafd42bd30e49dce52ed50ec755 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,878 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Conv backward data multiple D: +// input : output image A: [G, N, K, Ho, Wo] +// input : weight B: [G, K, C, Y, X], +// input : D0, D1, ... : [G, N, K, Ho, Wo] +// output : input image E: [G, N, C, Hi, Wi] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + : public DeviceGroupedConvBwdDataMultipleD +{ + // TODO: Extend support for more spatial dimensions. + static_assert(NDimSpatial == 2 || NDimSpatial == 3, + "wrong! only implemented for 2D and 3D now"); + + using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + // TODO: Add support for different A and B data types. + using ABDataType = ADataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr index_t KPerBlock = K0PerBlock * K1; + + static constexpr auto transform_conv_to_gemm = + TransformConvBwdDataToGemm_v1{}; + + static auto GetDummyABDsEGridDescriptor() + { + const std::array dummy_tensor_lengths = {1}; + const std::array dummy_tensor_strides = {1}; + const std::array dummy_spatial_lengths = {1}; + + const auto a_grid_desc_ak0_m_ak1 = + transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + const auto b_grid_desc_bk0_n_bk1 = + transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + const auto ds_grid_desc_m_n = generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return transform_conv_to_gemm.template MakeCDescriptor_M_N( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + }, + Number{}); + + const auto e_grid_desc_m_n = + transform_conv_to_gemm.template MakeCDescriptor_M_N(dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + return make_tuple( + a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + + // desc + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor()); + + using AGridDesc_AK0_M_AK1 = remove_cvref_t>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t>; + using DsGridDesc_M_N = remove_cvref_t>; + using EGridDesc_M_N = remove_cvref_t>; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + // InMemory Data Descriptor + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + DsGridDesc_M_N, + EGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + K0PerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{})); + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, + const std::array& a_g_n_k_wos_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, + const std::array& e_g_n_c_wis_lengths, + const std::array& e_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_k_wos_lengths[0]}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, + a_g_n_k_wos_strides_{a_g_n_k_wos_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths}, + ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides}, + e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, + e_g_n_c_wis_strides_{e_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // populate Ds pointer + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds[i]); + }); + + // A/B/Ds/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0]; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; + }); + + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + // problem definition + const index_t Z = b_g_k_c_xs_lengths[ZIdx]; + const index_t Y = b_g_k_c_xs_lengths[YIdx]; + const index_t X = b_g_k_c_xs_lengths[XIdx]; + + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; + + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice * ZDotSlice <= 0) + { + continue; + } + + std::array tildes; + if constexpr(NDimSpatial == 2) + { + tildes = {i_ytilde, i_xtilde}; + } + else if constexpr(NDimSpatial == 3) + { + tildes = {i_ztilde, i_ytilde, i_xtilde}; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } + + const auto a_grid_desc_ak0_m_ak1 = + transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + const auto b_grid_desc_bk0_n_bk1 = + transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + DsGridDesc_M_N ds_grid_desc_m_n; + + // populate Ds desc + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(i) = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths[i], + ds_g_n_c_wis_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + }); + + const auto e_grid_desc_m_n = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + // for check validity + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); + + // desc for blockwise copy + a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1); + b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1); + + // block-to-e-tile-map + auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap( + e_grid_desc_m_n, 1 /* M01 */, 1 /* N01 */); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n)); + e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n)); + } + } + } + } + + void Print() const + { + for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + std::cout << "a_grid_desc_ak0_m_ak1_container_" + << a_grid_desc_ak0_m_ak1_container_[i] << std::endl; + + std::cout << "b_grid_desc_bk0_n_bk1_container_" + << b_grid_desc_bk0_n_bk1_container_[i] << std::endl; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i][j] + << std::endl; + }); + + std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i] + << std::endl; + } + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptor for problem definition + index_t num_group_; + std::vector ds_grid_desc_m_n_container_; + std::vector e_grid_desc_m_n_container_; + + // tensor descriptor for block-wise copy + std::vector a_grid_desc_ak0_m_ak1_container_; + std::vector b_grid_desc_bk0_n_bk1_container_; + std::vector + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_; + std::vector + e_grid_desc_mblock_mperblock_nblock_nperblock_container_; + + // block-to-e-tile map + std::vector block_2_ctile_map_container_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOp a_element_op_; + BElementwiseOp b_element_op_; + CDEElementwiseOp cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_k_wos_lengths_; + std::array a_g_n_k_wos_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_c_wis_lengths_; + std::array, NumDTensor> ds_g_n_c_wis_strides_; + std::array e_g_n_c_wis_lengths_; + std::array e_g_n_c_wis_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float ave_time = 0; + + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( + arg.e_grid_desc_m_n_container_[i]) * + arg.num_group_; + + const auto GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * + arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_k_wos_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_container_[i], + arg.b_grid_desc_bk0_n_bk1_container_[i], + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.block_2_ctile_map_container_[i], + arg.compute_ptr_offset_of_batch_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK)) + { + ave_time += launch_kernel(integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + // check device + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + + // Specialization + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's a 1x1 convolution with stride=1 and no padding + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load for A matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector load for B matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v) + { + if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector store for Ds + bool ds_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + // vector load D matrix from global memory + if(!(ConvC % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + ds_valid = false; + } + } + else + { + ds_valid = false; + } + }); + + if(!ds_valid) + { + return false; + } + + // vector store for E + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + // vector store C matrix into global memory + if(!(ConvC % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_container_[i], + arg.b_grid_desc_bk0_n_bk1_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + return false; + } + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", " + << K1 << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 8a0a4453702a081c02027653e3c13fa61b8e8c46..c0fa9ad882a1263e44d76103b479d7b4af3f1336 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" @@ -24,51 +25,6 @@ namespace device { namespace { -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - /* * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. * @@ -131,7 +87,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -242,7 +198,9 @@ template + LoopScheduler LoopSched = make_default_loop_scheduler(), + typename AComputeType = ADataType, + typename BComputeType = AComputeType> struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 : public DeviceGroupedConvBwdDataMultipleD + CDEElementwiseOp, + AComputeType, + BComputeType> { - // FIXME + // TODO: Extend support for more spatial dimensions. static_assert(NDimSpatial == 2 || NDimSpatial == 3, "wrong! only implemented for 2D and 3D now"); @@ -265,7 +225,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static constexpr index_t NumDTensor = DsDataType::Size(); - // TODO make A/B datatype different + // TODO: Add support for different A and B data types. using ABDataType = ADataType; static constexpr auto I0 = Number<0>{}; @@ -280,6 +240,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 BK1, MPerBlock, NPerBlock, + KPerBlock, DoPadGemmM, DoPadGemmN>{}; @@ -355,7 +316,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ABDataType, // TODO: distinguish A/B datatype + ABDataType, + ABDataType, + AComputeType, AccDataType, CShuffleDataType, DsDataType, @@ -395,7 +358,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, - LoopSched>; + LoopSched, + PipelineVersion::v1, + BComputeType>; template static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) @@ -712,7 +677,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 std::vector block_2_etile_map_container_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; // element-wise op AElementwiseOp a_element_op_; @@ -781,7 +746,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, Block2ETileMap, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop>; return launch_and_time_kernel( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp similarity index 62% rename from include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 198751cdf350f7f9d828edda641b8ef217e98af6..534467b9591a5a9cc086168d55e24babad1a6f37 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -12,8 +12,10 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -21,32 +23,6 @@ namespace ck { namespace tensor_operation { namespace device { -namespace { - -struct ComputePtrOffsetOfStridedBatch -{ - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - index_t BatchStrideC_; -}; - -} // namespace - template {}, integral_constant{}); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = batch_count; + ignore = a_grid_desc_kbatch_k0_m0_m1_k1; + ignore = b_grid_desc_kbatch_k0_n0_n1_k1; + ignore = c_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = block_2_ctile_map; + ignore = compute_ptr_offset_of_batch; +#endif } template -struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl - : public DeviceGroupedConvBwdWeight< - NDimSpatial, - ck::tuple_element_t>, - ck::tuple_element_t>, - ck::tuple_element_t>, - InDataType, - WeiDataType, - OutDataType, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> +struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight { - using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl; + // 1d + static constexpr bool is_NWGK_GKXC_NWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNWK_GKXC_GNWC = + is_same_v && + is_same_v && + is_same_v; + // 2d + static constexpr bool is_NHWGK_GKYXC_NHWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNHWK_GKYXC_GNHWC = + is_same_v && + is_same_v && + is_same_v; + // 3d + static constexpr bool is_NDHWGK_GKZYXC_NDHWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNDHWK_GKZYXC_GNDHWC = + is_same_v && + is_same_v && + is_same_v; + + using DeviceOp = DeviceGroupedConvBwdWeight_Dl; using ADataType = OutDataType; using BDataType = InDataType; @@ -176,6 +185,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl static constexpr auto I4 = Number<4>{}; static constexpr auto I5 = Number<5>{}; + static constexpr auto spatial_offset = I3; + static constexpr auto K1Number = Number{}; static constexpr auto GemmK1Number = K1Number; @@ -195,12 +206,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -209,90 +220,102 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl { using namespace ck; - const index_t Wi = input_spatial_lengths[0]; - const index_t Wo = output_spatial_lengths[0]; - const index_t X = filter_spatial_lengths[0]; - const index_t InLeftPadW = input_left_pads[0]; - const index_t InRightPadW = input_right_pads[0]; - const index_t ConvStrideW = conv_filter_strides[0]; - const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset]; + const index_t InLeftPadW = input_left_pads[I0]; + const index_t InRightPadW = input_right_pads[I0]; + const index_t ConvStrideW = conv_filter_strides[I0]; + const index_t ConvDilationW = conv_filter_dilations[I0]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset]; const index_t GemmKTotal = N * Wo; - const index_t GemmM = K; - const index_t GemmN = C * X; - const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride)); - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C)); + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wi, C), make_tuple(InWStride, InCStride)); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weights tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } else { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); - const auto in_n_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride)); // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -321,38 +344,43 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } } // function end template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -361,103 +389,111 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl { using namespace ck; - const index_t Hi = input_spatial_lengths[0]; - const index_t Wi = input_spatial_lengths[1]; - - const index_t Ho = output_spatial_lengths[0]; - const index_t Wo = output_spatial_lengths[1]; - - const index_t Y = filter_spatial_lengths[0]; - const index_t X = filter_spatial_lengths[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Hi = a_g_n_c_wis_lengths[spatial_offset]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I1]; + const index_t Ho = e_g_n_k_wos_lengths[spatial_offset]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I1]; + const index_t Y = b_g_k_c_xs_lengths[spatial_offset]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset + I1]; + + const index_t InLeftPadH = input_left_pads[I0]; + const index_t InLeftPadW = input_left_pads[I1]; + const index_t InRightPadH = input_right_pads[I0]; + const index_t InRightPadW = input_right_pads[I1]; + const index_t ConvStrideH = conv_filter_strides[I0]; + const index_t ConvStrideW = conv_filter_strides[I1]; + const index_t ConvDilationH = conv_filter_dilations[I0]; + const index_t ConvDilationW = conv_filter_dilations[I1]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InHStride = a_g_n_c_wis_strides[spatial_offset]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I1]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1]; const index_t GemmKTotal = N * Ho * Wo; - const index_t GemmM = K; - const index_t GemmN = C * X * Y; - const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C)); + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride)); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } else { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - const auto in_n_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(InNStride, InHStride, InWStride, InCStride)); // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -488,39 +524,44 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } } // function end template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -529,110 +570,120 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl { using namespace ck; - const index_t Di = input_spatial_lengths[0]; - const index_t Hi = input_spatial_lengths[1]; - const index_t Wi = input_spatial_lengths[2]; - - const index_t Do = output_spatial_lengths[0]; - const index_t Ho = output_spatial_lengths[1]; - const index_t Wo = output_spatial_lengths[2]; - - const index_t Z = filter_spatial_lengths[0]; - const index_t Y = filter_spatial_lengths[1]; - const index_t X = filter_spatial_lengths[2]; - - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; - - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; - - const index_t ConvStrideD = conv_filter_strides[0]; - const index_t ConvStrideH = conv_filter_strides[1]; - const index_t ConvStrideW = conv_filter_strides[2]; - - const index_t ConvDilationD = conv_filter_dilations[0]; - const index_t ConvDilationH = conv_filter_dilations[1]; - const index_t ConvDilationW = conv_filter_dilations[2]; + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Di = a_g_n_c_wis_lengths[spatial_offset + I0]; + const index_t Hi = a_g_n_c_wis_lengths[spatial_offset + I1]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I2]; + const index_t Do = e_g_n_k_wos_lengths[spatial_offset + I0]; + const index_t Ho = e_g_n_k_wos_lengths[spatial_offset + I1]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I2]; + const index_t Z = b_g_k_c_xs_lengths[spatial_offset + I0]; + const index_t Y = b_g_k_c_xs_lengths[spatial_offset + I1]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset + I2]; + + const index_t InLeftPadD = input_left_pads[I0]; + const index_t InLeftPadH = input_left_pads[I1]; + const index_t InLeftPadW = input_left_pads[I2]; + const index_t InRightPadD = input_right_pads[I0]; + const index_t InRightPadH = input_right_pads[I1]; + const index_t InRightPadW = input_right_pads[I2]; + const index_t ConvStrideD = conv_filter_strides[I0]; + const index_t ConvStrideH = conv_filter_strides[I1]; + const index_t ConvStrideW = conv_filter_strides[I2]; + const index_t ConvDilationD = conv_filter_dilations[I0]; + const index_t ConvDilationH = conv_filter_dilations[I1]; + const index_t ConvDilationW = conv_filter_dilations[I2]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InDStride = a_g_n_c_wis_strides[spatial_offset]; + const auto InHStride = a_g_n_c_wis_strides[spatial_offset + I1]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I2]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2]; const index_t GemmKTotal = N * Do * Ho * Wo; - const index_t GemmM = K; - const index_t GemmN = C * Z * X * Y; - const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C)); + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride)); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } else { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); - const auto in_n_di_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride)); // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -672,27 +723,32 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } } // function end @@ -701,22 +757,22 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl static auto GetABCGridDesc() { return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( - 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); + {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); } template ::type = false> static auto GetABCGridDesc() { return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( - 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); + {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); } template ::type = false> static auto GetABCGridDesc() { - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, - 1, - 1, + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>({1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, @@ -785,11 +841,11 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl WeiDataType* p_wei_grid, const OutDataType* p_out_grid, const std::array& a_g_n_c_wis_lengths, // input - const std::array& /*a_g_n_c_wis_strides*/, + const std::array& a_g_n_c_wis_strides, const std::array& b_g_k_c_xs_lengths, // weight - const std::array& /*b_g_k_c_xs_strides*/, + const std::array& b_g_k_c_xs_strides, const std::array& e_g_n_k_wos_lengths, // output - const std::array& /*e_g_n_k_wos_strides*/, + const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -809,38 +865,24 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl a_element_op_{out_element_op}, b_element_op_{wei_element_op}, c_element_op_{in_element_op}, - Conv_G_{a_g_n_c_wis_lengths[0]}, - Conv_N_{a_g_n_c_wis_lengths[1]}, - Conv_K_{b_g_k_c_xs_lengths[1]}, - Conv_C_{a_g_n_c_wis_lengths[2]}, - input_spatial_lengths_{}, - filter_spatial_lengths_{}, - output_spatial_lengths_{}, + Conv_G_{a_g_n_c_wis_lengths[I0]}, + Conv_K_{b_g_k_c_xs_lengths[I1]}, + Conv_C_{a_g_n_c_wis_lengths[I2]}, + filter_lengths_{b_g_k_c_xs_lengths}, conv_filter_strides_{conv_filter_strides}, conv_filter_dilations_{conv_filter_dilations}, input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads}, k_batch_{split_k} { - constexpr index_t spatial_offset = 3; - std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, - end(a_g_n_c_wis_lengths), - begin(input_spatial_lengths_)); - std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset, - end(b_g_k_c_xs_lengths), - begin(filter_spatial_lengths_)); - std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset, - end(e_g_n_k_wos_lengths), - begin(output_spatial_lengths_)); - const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - Conv_N_, - Conv_K_, - Conv_C_, - input_spatial_lengths_, - filter_spatial_lengths_, - output_spatial_lengths_, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -863,24 +905,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); // A/B/C Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = - Conv_N_ * Conv_K_ * - std::accumulate(begin(output_spatial_lengths_), - end(output_spatial_lengths_), - index_t{1}, - std::multiplies<>{}); - compute_ptr_offset_of_batch_.BatchStrideB_ = - Conv_N_ * Conv_C_ * - std::accumulate(begin(input_spatial_lengths_), - end(input_spatial_lengths_), - index_t{1}, - std::multiplies<>{}); - compute_ptr_offset_of_batch_.BatchStrideC_ = - Conv_K_ * Conv_C_ * - std::accumulate(begin(filter_spatial_lengths_), - end(filter_spatial_lengths_), - index_t{1}, - std::multiplies<>{}); + compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = b_g_k_c_xs_strides[I0]; } const ADataType* p_a_grid_; @@ -899,7 +926,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl Block2CTileMap block_2_ctile_map_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; // element-wise op OutElementwiseOperation a_element_op_; @@ -908,13 +935,10 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl // for checking IsSupportedArgument() const index_t Conv_G_; - const index_t Conv_N_; const index_t Conv_K_; const index_t Conv_C_; - std::array input_spatial_lengths_; - std::array filter_spatial_lengths_; - std::array output_spatial_lengths_; + std::array filter_lengths_; const std::array& conv_filter_strides_; const std::array& conv_filter_dilations_; const std::array& input_left_pads_; @@ -974,7 +998,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl remove_reference_t, remove_reference_t, remove_reference_t, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch<>, has_main_loop, has_double_loop>; @@ -1036,10 +1060,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl static bool IsSupportedArgument(const Argument& arg) { - // check device - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102")) + + // DL version only supports split_k equal to 1 + if(arg.k_batch_ != 1) + return false; + + if constexpr(!((NDimSpatial == 1 && (is_NWGK_GKXC_NWGC || is_GNWK_GKXC_GNWC)) || + (NDimSpatial == 2 && (is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) || + (NDimSpatial == 3 && (is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC)))) { return false; } @@ -1050,8 +1078,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl // check if it's 1x1, stride=1 pad = 0 conv for(int i = 0; i < NDimSpatial; i++) { - if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && - arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + if(!(arg.filter_lengths_[spatial_offset + i] == 1 && + arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 && + arg.input_right_pads_[i] == 0)) { return false; } @@ -1206,7 +1235,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl auto str = std::stringstream(); // clang-format off - str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl" + str << "DeviceGroupedConvBwdWeight_Dl" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8850b13d0a47cce571c2018a9ed5cd43f1572e6f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -0,0 +1,876 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template ::type = false> +struct DeviceGroupedConvBwdWeight_Wmma_CShuffle + : public DeviceGroupedConvBwdWeight +{ + using DeviceOp = DeviceGroupedConvBwdWeight_Wmma_CShuffle; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + // 3d + static constexpr bool is_NDHWGK_GKZYXC_NDHWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNDHWK_GKZYXC_GNDHWC = + is_same_v && + is_same_v && + is_same_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto GemmK1Number = Number{}; + static constexpr index_t KPerBlock = K0PerBlock * GemmK1Number; + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Do, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Di, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C), + make_tuple(WiStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Z, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C), + make_tuple(KStride, CStride)); + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * Z * X * Y; + + const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock; + const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock; + + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock; + const index_t GemmKPad = GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Pad + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } + + template ::type = false> + static auto GetABCGridDesc() + { + const index_t dim = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + // DataType Family + ADataType, + BDataType, + AccDataType, + CDataType, + Tuple<>, + CDataType, + // InMemory Data Descriptor + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + Tuple<>, + CGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + K0PerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(Tuple<>{})); + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{})); + + using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap( + CGridDesc_M_N{}, I1 /* M01 */, I1 /* N01 */)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_G_{a_g_n_c_wis_lengths[0]}, + Conv_N_{a_g_n_c_wis_lengths[1]}, + Conv_K_{b_g_k_c_xs_lengths[1]}, + Conv_C_{a_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + constexpr index_t spatial_offset = 3; + std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, + end(a_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset, + end(b_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset, + end(e_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + a_g_n_c_wis_strides, + b_g_k_c_xs_strides, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap( + c_grid_desc_m_n_, I1 /* M01 */, I1 /* N01 */); + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), + index_t{1}, + std::multiplies<>{}); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + Block2CTileMap block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void Print(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(arg); + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle has invalid " + "setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_; + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + CDataType, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + remove_reference_t, + ComputePtrOffsetOfStridedBatch<>, + has_main_loop>; + + using EmptyTuple = Tuple<>; + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + EmptyTuple{}, // Ds + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.Conv_G_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock{}, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + if(has_main_k0_block_loop) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // check device + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + // TODO: Add support for split_k > 1 + if(arg.k_batch_ != 1) + { + return false; + } + + if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC)) + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's a 1x1 convolution with stride=1 and no padding + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeight_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << CShuffleBlockTransferScalarPerVector_NPerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 71c528e4c92e339b334d78faab82fe10c7bf914e..26b0eae915b2e8a331f0731566c35495bdc35b2f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -21,34 +22,9 @@ namespace ck { namespace tensor_operation { namespace device { -namespace { - -struct ComputePtrOffsetOfStridedBatch -{ - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - index_t BatchStrideC_; -}; - -} // namespace - template (compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); - __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; + __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; GridwiseGemm::template Run(p_a_grid + a_batch_offset, p_b_grid + b_batch_offset, @@ -163,7 +139,9 @@ template + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedConvBwdWeight_Xdl_CShuffle : public DeviceGroupedConvBwdWeight + OutElementwiseOperation, + ComputeTypeA, + ComputeTypeB> { using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle; @@ -1045,7 +1025,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, - ADataType, // TODO: distinguish A/B datatype + ADataType, + BDataType, AccDataType, CDataType, InMemoryDataOperationEnum::AtomicAdd, @@ -1090,7 +1071,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, - true>; + true, + 1, + PipelineVersion::v1, + ComputeTypeA, + ComputeTypeB>; // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -1212,13 +1197,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle Block2CTileMap block_2_ctile_map_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; index_t M01_; index_t N01_; - InElementwiseOperation a_element_op_; - OutElementwiseOperation b_element_op_; + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; WeiElementwiseOperation c_element_op_; // for checking IsSupportedArgument() @@ -1281,7 +1266,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype + ADataType, + BDataType, CDataType, OutElementwiseOperation, InElementwiseOperation, @@ -1290,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle remove_reference_t, remove_reference_t, remove_reference_t, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch<>, has_main_loop>; return launch_and_time_kernel(stream_config, @@ -1337,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } if constexpr(NDimSpatial == 1) { if constexpr(!is_GNWK_GKXC_GNWC) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 8b22bd209043e900cd594bcac7923833e9a26536..c3023301f39f9edde9ecc9c0aadde9b14609861b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -15,10 +15,11 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" @@ -29,51 +30,6 @@ namespace device { namespace { -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - /* * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. * @@ -134,9 +90,8 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ - defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \ - defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -260,18 +215,18 @@ template struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK - : public DeviceGroupedConvFwdMultipleD + : public DeviceGroupedConvFwdMultipleABD { using DeviceOp = DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK; @@ -581,7 +536,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK DefaultBlock2CTileMap block_2_ctile_map_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; // element-wise op AElementwiseOperation a_element_op_; @@ -645,7 +600,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11, DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11, DefaultBlock2CTileMap, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop, has_double_loop>; @@ -710,11 +665,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK namespace ctc = tensor_layout::convolution; // check device - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx908" || - ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" || - ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102" || - ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")) + if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || + ck::is_navi2_supported() || ck::is_navi3_supported())) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index f18fbcfe4b51031009f6329baac2bc1ac209e1fa..d731e5ddac321dc093501b40c935b1ebb162a856 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -106,8 +106,8 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ - defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx11__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -601,9 +601,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + if constexpr(isMultiA || isMultiB) + { + AsPointer p_as_grid_grp; + BsPointer p_bs_grid_grp; + + const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx); + + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}( + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; }); + + const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx); + + static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); + static_for<0, NumBTensor, 1>{}( + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; }); + + GridwiseGemm::template Run( + p_as_grid_grp, + p_bs_grid_grp, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } + else + { + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + + GridwiseGemm::template Run( + p_as_grid + a_batch_offset, + p_bs_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +} // namespace + +template +using is_tuple = decltype(std::declval().IsTuple()); + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + LoopScheduler LoopSched = make_default_loop_scheduler()> +struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; + + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + static auto + MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + static auto + MakeBGridDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + static auto + MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // Shape of Ds and E must be aligned. Strides can be different. + // Pass e_g_n_k_wos_lengths for logical broadcast. + static auto MakeDsGridDescriptor_M_N( + const std::array& e_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, + ds_g_n_k_wos_strides[i]); + }, + Number{}); + } + + // desc for problem definition + using AGridDesc_M_K = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using BGridDesc_N_K = remove_cvref_t({}, {}))>; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = remove_cvref_t({}, {}))>; + + // If we are using multiAB and one of the template datatype parameters is not a tuple, convert + // it to it + using GemmADataType = std::conditional_t, ADataType>; + using GemmBDataType = std::conditional_t, BDataType>; + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched + // Use appropriate gridwise gemm + using GridwiseGemm = + std::conditional_t, + GridwiseGemmMultipleD_xdl_cshuffle>; + + // If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers. + using APointers = + std::conditional_t&, const void*>; + using BPointers = + std::conditional_t&, const void*>; + // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not + // in initializer list what is required for single const pointer). + using AGridPointer = remove_cvref_t< + decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>; + using BGridPointer = remove_cvref_t< + decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_c_wis_lengths[0]}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + compute_ptr_offset_of_batch_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + if constexpr(isMultiA || isMultiB) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_batch_ for multiple AB + compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0]; + + // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data + // type is not tuple) + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiA) + { + // p_as is tuple + p_as_grid_(i) = static_cast(p_as[i.value]); + } + else + { + // if MultiB and not MultiA then p_as is single pointer + p_as_grid_(i) = static_cast(p_as); + } + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_batch_ for multiple AB + compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0]; + + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiB) + { + // p_bs is tuple + p_bs_grid_(i) = static_cast(p_bs[i.value]); + } + else + { + // if MultiA and not MultiB then p_bs is single pointer + p_bs_grid_(i) = static_cast(p_bs); + } + }); + } + else + { + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + + // p_as and p_bs are pointers + p_as_grid_(I0) = static_cast(p_as); + p_bs_grid_(I0) = static_cast(p_bs); + } + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + + // D desc + ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( + e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]); + }); + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + // populate desc for Ds/E + if constexpr(isMultiA || isMultiB) + { + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{}); + + if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + else + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers (tuple if multi AB, pointer if no) + AGridPointer p_as_grid_; + BGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch + compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + if constexpr(isMultiA || isMultiB) + { + // Generate tuples with grid descriptors for each A and B + const auto as_grid_desc_ak0_m_ak1 = generate_tuple( + [&](auto) { return arg.a_grid_desc_ak0_m_ak1_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( + [&](auto) { return arg.b_grid_desc_bk0_n_bk1_; }, Number{}); + + const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + AGridPointer, + BGridPointer, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + decltype(as_grid_desc_ak0_m_ak1), + decltype(bs_grid_desc_bk0_n_bk1), + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch, + has_main_loop, + isMultiA, + isMultiB>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_as_grid_, + arg.p_bs_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_c_wis_lengths_[0], // Group count + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_batch_); + } + else + { + const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + const ADataType*, + const BDataType*, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch, + has_main_loop, + isMultiA, + isMultiB>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple + arg.p_bs_grid_.At(I0), // Pass just B descriptor instead of tuple + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_batch_); + } + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check device + if(get_device_name() == "gfx908") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + return false; + } + } + else if(ck::is_lds_direct_load_supported()) + { + if constexpr(!(is_same_v || is_same_v || + is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of Ds + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; + + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + valid = false; + } + + if constexpr(is_same_v) + { + // G and K must be the same + if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] || + arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2]) + { + valid = false; + } + } + else + { + // E and D must have the same shape + for(index_t d = 0; d < NDimSpatial + 3; d++) + { + if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d]) + { + valid = false; + } + } + } + } + else + { + valid = false; + } + }); + + if(!valid) + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.e_g_n_k_wos_lengths_[2]; + + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // check Gridwise GEMM + if constexpr(isMultiA || isMultiB) + { + // Genarate tuples with the same descriptors + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number{}); + return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + APointers p_a, + BPointers p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CDEBlockTransferScalarPerVector_NPerBlock << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index caa18b709cea604930c3c81c80cdbebc5d9cfaa1..ab1c4fc08f60bfae145c80104e593e2d55115971 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -156,7 +156,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -813,8 +813,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle return false; } } - else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" || - get_device_name() == "gfx941" || get_device_name() == "gfx942") + else if(ck::is_lds_direct_load_supported()) { if constexpr(!(is_same_v || is_same_v || is_same_v || is_same_v)) @@ -834,7 +833,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle // check if it's 1x1, stride=1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t ConvStride = arg.conv_filter_strides_[i]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; @@ -851,7 +850,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle // check if it's 1x1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; @@ -1090,7 +1089,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle auto str = std::stringstream(); // clang-format off - str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle" + str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 7c726cd85128ff61ac4eb675d5c3dee9e9de3315..ba2a4b0f7aedbd532fa597873870e831d9cb33f3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -15,10 +15,11 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" @@ -27,55 +28,6 @@ namespace ck { namespace tensor_operation { namespace device { -namespace { - -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - -} // namespace - // // @brief Device Convolution operation. // @@ -140,18 +92,18 @@ template struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle - : public DeviceGroupedConvFwdMultipleD + : public DeviceGroupedConvFwdMultipleABD { using DeviceOp = DeviceGroupedConvFwdMultipleD_Wmma_CShuffle; @@ -476,7 +428,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle typename GridwiseOp::DefaultBlock2CTileMap block_2_etile_map_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; // element-wise op AElementwiseOperation a_element_op_; @@ -519,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle auto launch_kernel = [&](auto has_main_k_block_loop) { constexpr bool has_main_loop = has_main_k_block_loop.value; - const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle< + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< GridwiseOp, ADataType, BDataType, @@ -533,7 +485,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, remove_reference_t, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop>; return launch_and_time_kernel(stream_config, @@ -579,8 +531,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle namespace ctc = tensor_layout::convolution; // check device - if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102") + if(ck::is_navi3_supported()) { if constexpr(!(is_same_v || is_same_v)) { @@ -599,7 +550,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // check if it's 1x1, stride=1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t ConvStride = arg.conv_filter_strides_[i]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; @@ -616,7 +567,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // check if it's 1x1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; @@ -679,8 +630,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v) + is_same_v || is_same_v) { const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp index c80598c4e335aaf8f448d3ad1af6c0e6172768d1..e88ca9129302fa775cec5c3cf8f632de29c19fce 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp @@ -3,200 +3,20 @@ #pragma once -#include -#include -#include -#include -#include - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/host_utility/io.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" namespace ck { namespace tensor_operation { namespace device { -namespace { - -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - -/* - * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. - * - * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix - * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly - * strided batched, but we can easily extend to other layouts. The returned offset can be either \p - * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB - * limitations. - * - * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and - * returns the 2D index of the tile that it computes. \see - * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). - * - * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 - * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid - * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link - * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for - * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the - * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. - * - * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. - * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to - * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). - * - */ -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, - const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) - // offset base pointer for each work-group - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - - const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - DsPointer p_ds_grid_grp; - - static constexpr index_t NumDTensor = - DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); - - static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_ctile_map); -#else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_ds_grid; - ignore = p_e_grid; - ignore = batch_count; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; - ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; - ignore = a_element_op; - ignore = b_element_op; - ignore = cde_element_op; - ignore = compute_ptr_offset_of_batch; - ignore = block_2_ctile_map; -#endif -} - -} // namespace - // // @brief Device Convolution operation. -// +// @note This structure is deprecated (left for backwards compatibility). Please use +// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle. // Supports: // @li Forward convolution with up to 3 spatial dimentions // @li Input tensor in GNWC data format @@ -255,711 +75,61 @@ template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed LoopScheduler LoopSched = make_default_loop_scheduler()> -struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle - : public DeviceGroupedConvFwdMultipleD -{ - using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle; - - static constexpr index_t NumDTensor = DsDataType::Size(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr auto conv_to_gemm_transformer = - TransformConvFwdToGemm{}; - - static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; - - template - static auto - MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const auto in_gemmmraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); - - const auto in_gemmm_gemmk_desc = - matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); - - return in_gemmm_gemmk_desc; - } - - template - static auto - MakeBGridDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) - { - const auto wei_gemmnraw_gemmkraw_desc = - conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides); - - const auto wei_gemmn_gemmk_desc = - matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); - - return wei_gemmn_gemmk_desc; - } - - template - static auto - MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides) - { - const auto out_gemmmraw_gemmnraw_desc = - conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides); - - const auto out_gemmm_gemmn_desc = - matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); - - return out_gemmm_gemmn_desc; - } - - static auto MakeDsGridDescriptor_M_N( - const std::array, NumDTensor>& ds_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides) - { - return generate_tuple( - [&](auto i) { - using DLayout = remove_cvref_t>; - - return DeviceOp::MakeEGridDescriptor_M_N(ds_g_n_k_wos_lengths[i], - ds_g_n_k_wos_strides[i]); - }, - Number{}); - } - - // desc for problem definition - using AGridDesc_M_K = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; - using BGridDesc_N_K = remove_cvref_t({}, {}))>; - using DsGridDesc_M_N = remove_cvref_t; - using EGridDesc_M_N = remove_cvref_t({}, {}))>; - - // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CShuffleDataType, - DsDataType, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, - NumGemmKPrefetchStage, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEBlockTransferScalarPerVector_NPerBlock, - LoopSched>; - - // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = - remove_cvref_t; - using BGridDesc_BK0_N_BK1 = - remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; - - // block-to-e-tile map - using Block2ETileMap = - remove_cvref_t; - - // Argument - struct Argument : public BaseArgument - { - Argument(const void* p_a, - const void* p_b, - const std::array& p_ds, - void* p_e, - const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array, NumDTensor>& - ds_g_n_k_wos_lengths, - const std::array, NumDTensor>& - ds_g_n_k_wos_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CDEElementwiseOperation& cde_element_op) - : p_a_grid_{static_cast(p_a)}, - p_b_grid_{static_cast(p_b)}, - p_ds_grid_{}, - p_e_grid_{static_cast(p_e)}, - num_group_{a_g_n_c_wis_lengths[0]}, - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads)}, - b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides)}, - ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, - e_g_n_k_wos_strides)}, - a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, - b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, - compute_ptr_offset_of_batch_{}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op}, - a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, - a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, - b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, - b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, - ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, - ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, - e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, - e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, - conv_filter_strides_{conv_filter_strides}, - conv_filter_dilations_{conv_filter_dilations}, - input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads} - { - // A/B/E Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; - - // populate pointer, batch stride, desc for Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DLayout = remove_cvref_t>; - using DDataType = remove_cvref_t>; - - // D pointer - p_ds_grid_(i) = static_cast(p_ds[i]); - - // D batch stride - compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; - - // D desc - ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N( - ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); - }); - - // populate desc for Ds/E - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); - } - } - - void Print() const - { - std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; - std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); - std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; - } - - // private: - // pointers - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; - EDataType* p_e_grid_; - - // tensor descriptors for problem definiton - index_t num_group_; - AGridDesc_M_K a_grid_desc_m_k_; - BGridDesc_N_K b_grid_desc_n_k_; - DsGridDesc_M_N ds_grid_desc_m_n_; - EGridDesc_M_N e_grid_desc_m_n_; - - // tensor descriptors for block/thread-wise copy - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock_; - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; - - // block-to-e-tile map - Block2ETileMap block_2_etile_map_; - - // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; - - // element-wise op - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - - // for checking IsSupportedArgument() - std::array a_g_n_c_wis_lengths_; - std::array a_g_n_c_wis_strides_; - std::array b_g_k_c_xs_lengths_; - std::array b_g_k_c_xs_strides_; - std::array, NumDTensor> ds_g_n_k_wos_lengths_; - std::array, NumDTensor> ds_g_n_k_wos_strides_; - std::array e_g_n_k_wos_lengths_; - std::array e_g_n_k_wos_strides_; - std::array conv_filter_strides_; - std::array conv_filter_dilations_; - std::array input_left_pads_; - std::array input_right_pads_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(stream_config.log_level_ > 0) - { - arg.Print(); - } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_)) - { - throw std::runtime_error( - "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); - } - - const index_t grid_size = - arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; - - const auto K = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); - - auto launch_kernel = [&](auto has_main_k_block_loop) { - constexpr bool has_main_loop = has_main_k_block_loop.value; - - const auto kernel = kernel_grouped_conv_fwd_multiple_d_xdl_cshuffle< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - typename GridwiseGemm::DsGridPointer, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - Block2ETileMap, - ComputePtrOffsetOfStridedBatch, - has_main_loop>; - - return launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_ds_grid_, - arg.p_e_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.a_g_n_c_wis_lengths_[0], // Group count - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_etile_map_, - arg.compute_ptr_offset_of_batch_); - }; - - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) - { - return launch_kernel(integral_constant{}); - } - else - { - return launch_kernel(integral_constant{}); - } - } - - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static bool IsSupportedArgument(const Argument& arg) - { - namespace ctc = tensor_layout::convolution; - - // check device - if(get_device_name() == "gfx908") - { - if constexpr(!(is_same_v || is_same_v || - is_same_v)) - { - return false; - } - } - else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" || - get_device_name() == "gfx941" || get_device_name() == "gfx942") - { - if constexpr(!(is_same_v || is_same_v || - is_same_v || is_same_v)) - { - return false; - } - } - else - { - return false; - } - - // check ConvolutionForwardSpecialization - if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - // check if it's 1x1, stride=1 conv - for(index_t i = 0; i < NDimSpatial; ++i) - { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; - const index_t ConvStride = arg.conv_filter_strides_[i]; - const index_t LeftPad = arg.input_left_pads_[i]; - const index_t RightPad = arg.input_right_pads_[i]; - - if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) - { - return false; - } - } - } - else if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization::Filter1x1Pad0) - { - // check if it's 1x1 conv - for(index_t i = 0; i < NDimSpatial; ++i) - { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; - const index_t LeftPad = arg.input_left_pads_[i]; - const index_t RightPad = arg.input_right_pads_[i]; - - if(!(X == 1 && LeftPad == 0 && RightPad == 0)) - { - return false; - } - } - } - - // check vector access of A - // FIXME: layout - if constexpr(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v) - { - const index_t C = arg.a_g_n_c_wis_lengths_[2]; - - if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) - { - return false; - } - } - else - { - return false; - } - - // check vector access of B - // FIXME: layout - if constexpr(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v) - - { - const index_t C = arg.b_g_k_c_xs_lengths_[2]; - - if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) - { - return false; - } - } - else - { - return false; - } - - // check vector access of Ds - bool valid = true; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DLayout = remove_cvref_t>; - - // FIXME: layout - if constexpr(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v) - { - const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; - - if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) - { - valid = false; - } - } - else - { - valid = false; - } - }); - - if(!valid) - { - return false; - } - - // check vector access of E - if constexpr(is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v || is_same_v || - is_same_v) - { - const index_t K = arg.e_g_n_k_wos_lengths_[2]; - - if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) - { - return false; - } - } - else - { - return false; - } - - // check Gridwise GEMM - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); - } - - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument( - const void* p_a, - const void* p_b, - const std::array& p_ds, - void* p_e, - const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array, NumDTensor>& ds_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CDEElementwiseOperation& cde_element_op) - { - return Argument{p_a, - p_b, - p_ds, - p_e, - a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - ds_g_n_k_wos_lengths, - ds_g_n_k_wos_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - a_element_op, - b_element_op, - cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - std::unique_ptr MakeArgumentPointer( - const void* p_a, - const void* p_b, - const std::array& p_ds, - void* p_e, - const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array, NumDTensor>& ds_g_n_k_wos_lengths, - const std::array, NumDTensor>& ds_g_n_k_wos_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CDEElementwiseOperation& cde_element_op) override - { - return std::make_unique(p_a, - p_b, - p_ds, - p_e, - a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - ds_g_n_k_wos_lengths, - ds_g_n_k_wos_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - a_element_op, - b_element_op, - cde_element_op); - } - - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << KPerBlock << ", " - << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " - << MPerXDL << ", " - << NPerXDL << ", " - << MXdlPerWave << ", " - << NXdlPerWave << ", " - << ABlockTransferSrcScalarPerVector << ", " - << BBlockTransferSrcScalarPerVector << ", " - << CShuffleMXdlPerWavePerShuffle << ", " - << CShuffleNXdlPerWavePerShuffle - << ">"; - // clang-format on - - return str.str(); - } -}; +using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ConvForwardSpecialization, + GemmSpec, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + ComputeDataType, + LoopSched>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..35f4393e369d83175d884bcbbba20e8c50d2eac8 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp @@ -0,0 +1,189 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct ComputePtrOffsetOfStridedBatch +{ +}; + +template +struct ComputePtrOffsetOfStridedBatch 1 || NumBTensor > 1)>> +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(Array& BatchStrideAs, + Array& BatchStrideBs, + Array& BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideAs), + BatchStrideB_(BatchStrideBs), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const + { + Array as_offset; + static_for<0, NumATensor, 1>{}( + [&](auto i) { as_offset(i) = g_idx * static_cast(BatchStrideA_[i]); }); + return as_offset; + } + + __host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const + { + Array bs_offset; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { bs_offset(i) = g_idx * static_cast(BatchStrideB_[i]); }); + return bs_offset; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + return ds_offset; + } + + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + // alias for kernels without multiple D + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + Array BatchStrideA_; + Array BatchStrideB_; + Array BatchStrideDs_; + index_t BatchStrideE_; + index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D +}; + +template +struct ComputePtrOffsetOfStridedBatch> +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + Array BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + return ds_offset; + } + + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + // alias for kernels without multiple D + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + ck::index_t BatchStrideA_; + ck::index_t BatchStrideB_; + Array BatchStrideDs_; + index_t BatchStrideE_; + index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D +}; + +template +constexpr static auto GetNumABTensors() +{ + if constexpr(isTuple) + { + return Number{}; + } + else + { + return Number<1>{}; + } +} + +template +constexpr static auto GetAGridPointer() +{ + if constexpr(isTuple) + { + return typename GridwiseGemm::AsGridPointer{}; + } + else + { + return Tuple{}; + } +} + +template +constexpr static auto GetBGridPointer() +{ + if constexpr(isTuple) + { + return typename GridwiseGemm::BsGridPointer{}; + } + else + { + return Tuple{}; + } +} + +template +constexpr static auto UnpackDataType() +{ + if constexpr(isTuple) + { + // unpack if tuple + return tuple_element_t{}; + } + else + { + // if no, return Type + return Type{}; + } +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index 0190b3cee6b44af7656c34be866c9ab0855f7b0d..6f7d7c389404610ed7b1511ffc7656ed7f9f3d1c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -39,9 +39,8 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -668,26 +667,24 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..56cc8fb7525ca352d3c85610acfd6be459f7d0ff --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -0,0 +1,842 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + uint32_t* barrier_count, + const index_t barrier_size_grp, + const index_t group_count, + const index_t grid_size_grp, + const index_t KBatch, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + const index_t M = gemm_desc_ptr[group_id].M; + const index_t N = gemm_desc_ptr[group_id].N; + const index_t K = gemm_desc_ptr[group_id].K; + + if(M * N * K == 0) + return; + + const auto StrideA = gemm_desc_ptr[group_id].StrideA; + const auto StrideB = gemm_desc_ptr[group_id].StrideB; + const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; + const auto StrideE = gemm_desc_ptr[group_id].StrideE; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumDTensor = DsDataType::Size(); + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + + DsGridPointer p_ds_grid_; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + // D pointer + p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + const index_t mn_blocks = local_grid_size / KBatch; + + while(id_local < local_grid_size) + { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + + auto barrier_count_finished = + barrier_count + group_id * barrier_size_grp + id_local % mn_blocks; + + GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + barrier_count_finished, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + + id_off += grid_size_grp; + id_local += grid_size_grp; + } +#else + ignore = gemm_descs_const; + ignore = barrier_count; + ignore = barrier_size_grp; + ignore = group_count; + ignore = grid_size_grp; + ignore = KBatch; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif +} + +template +struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK +{ + using DeviceOp = DeviceGroupedGemm_Xdl_Fixed_NK; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle< + ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + NumPrefetch, // NumGemmKPrefetchStage + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + template + struct OffsettedBlockToCTileMapMLoops + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; + }; + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); + + block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; + + using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + + struct GemmBiasTransKernelArg + { + // pointers + const void* a_ptr_; + const void* b_ptr_; + std::array ds_ptr_; + void* e_ptr_; + + index_t M_, N_, K_; + index_t StrideA_, StrideB_; + std::array StrideDs_; + index_t StrideE_; + }; + + // Argument + struct Argument : public BaseArgument + { + + void UpdateKBatch(index_t k_batch) + { + k_batch_ = k_batch; + + if(k_batch_ < 1) + { + + throw std::runtime_error("wrong! k_batch must be > 0"); + } + + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + + const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE_; + const index_t N = gemm_desc_kernel_arg_[0].N_; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + grid_size_ = grid_size_grp_ * group_count_; + } + + Argument(std::vector&, + std::vector&, + std::vector>&, + std::vector&, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} + { + grid_size_ = 0; + + k_batch_ = 1; + + grouped_gemm_kernel_args_dev = nullptr; + + group_count_ = ck::type_convert(gemm_descs.size()); + + gemm_desc_kernel_arg_.reserve(group_count_); + + index_t group_id = 0; + + sum_of_m = gemm_descs[0].M_; + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + const index_t N = gemm_descs[0].N_; + const index_t K = gemm_descs[0].K_; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_) + { + throw std::runtime_error("wrong! M/N/K is not identical"); + } + + a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + const index_t StrideA = gemm_descs[i].stride_A_; + const index_t StrideB = gemm_descs[i].stride_B_; + const index_t StrideE = gemm_descs[i].stride_C_; + + // pointer + std::array p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; }); + + std::array StrideDs; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + // using DLayout = remove_cvref_t>; + + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + StrideDs[j] = gemm_descs[i].stride_Ds_[j]; + }); + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + // block-to-e-tile map + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + if(group_id * grid_size_grp_ != grid_size_) + { + throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); + } + + grid_size_ += grid_size_grp_; + + // check block-to-E-tile + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! block_2_etile_map validation failed"); + } + + if(!GridwiseGemm:: + template CheckValidity( + AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{ + nullptr, + nullptr, + p_ds_grid, + nullptr, + AverM, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + }); + + group_id++; + } + + const auto e_grid_desc_sum_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; + + barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); + } + + // private: + index_t group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + const void* grouped_gemm_kernel_args_dev; + + index_t grid_size_; + index_t grid_size_grp_; + index_t barrier_size_grp_; + index_t sum_of_m; + + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool has_main_k_block_loop = true; + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + const auto KPad = + GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, arg.k_batch_); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + if(arg.grouped_gemm_kernel_args_dev == nullptr) + { + throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr"); + } + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) { + const auto kernel = + kernel_grouped_gemm_xdl_fixed_nk, + GemmSpec, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + reinterpret_cast(arg.p_workspace_), + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + }; + + constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd; + constexpr auto Set = InMemoryDataOperationEnum::Set; + + if(arg.k_batch_ > 1) + { + if(has_main_k_block_loop) + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + } + else + { + if(has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + { + return false; + } + + bool supported = true; + + // If we use padding we do not support vector loads for dimensions not divisible by vector + // load size. + if constexpr(GemmSpec != GemmSpecialization::Default) + { + // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout, + // thus we have to adapt it to the {M,K} or {N,K} layout. + const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; + const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0; + + for(index_t i = 0; i < arg.group_count_; ++i) + { + const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + } + } + + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_Xdl_Fixed_NK" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args) + { + arg.grouped_gemm_kernel_args_dev = kernel_args; + } + + // polymorphic + void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), kernel_args); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * sizeof(GroupedGemmKernelArgument); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& stream_config = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + p_arg_->p_workspace_ = p_workspace; + + hip_check_error( + hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_)); + } + + static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } + + // polymorphic + void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override + { + return SetKBatch(*dynamic_cast(p_arg), k_batch); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index 41445aeaf076408adc7159359884803f0965e313..abee2fea5368bae2e1dc8990b1643295f320b835 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -35,7 +35,7 @@ __global__ void const index_t group_count) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -265,10 +265,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK 1; bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); @@ -384,7 +384,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK 1); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..52aeefa3a41af3b4032798344790aaa9fa0150e3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp @@ -0,0 +1,406 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Image to column: +// input : input image [G, N, Di, Hi, Wi, C] +// output : gemm form [G * N * Do * Ho * Wo, Z * Y * X * C] +// input : input image [N, Di, Hi, Wi, G, C] +// output : gemm form [N * Do * Ho * Wo * G, Z * Y * X * C] +template = 1 && NDimSpatial <= 3, bool>::type = false> +struct DeviceImageToColumnImpl + : public DeviceConvTensorRearrange +{ + static constexpr bool is_NSpatialGC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + static constexpr bool is_GNSpatialC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + + static constexpr auto matrix_padder = + MatrixPadder{ + MPerBlock, 0 /* NPerBlock*/, KPerBlock}; + + // Use MakeADescriptor_M_K from grouped convolution forward + static auto + MakeInputDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + std::array a_g_n_c_wis_lengths{1}; + std::array b_g_k_c_xs_lengths{1}; + std::array c_g_n_k_wos_lengths{1}; + + auto copy = [](const auto& x, auto& y, index_t dst_offset) { + std::copy(x.begin(), x.end(), y.begin() + dst_offset); + }; + + constexpr index_t spatial_offset = 3; + + copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset); + copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset); + copy(output_spatial_lengths, c_g_n_k_wos_lengths, spatial_offset); + + // fill only significant values (C and N) + a_g_n_c_wis_lengths[I1] = N; + a_g_n_c_wis_lengths[I2] = C; + b_g_k_c_xs_lengths[I2] = C; + c_g_n_k_wos_lengths[I1] = N; + + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K( + a_g_n_c_wis_lengths, + image_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + {}, // not needed for A Descriptor + c_g_n_k_wos_lengths, + {}, // not needed for A Descriptor + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + return in_gemmm_gemmk_desc; + } + + static auto + MakeOutDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& gemm_g_m_k_strides) + { + const index_t NDoHoWo = + N * ck::accumulate_n( + output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t CZYX = + C * ck::accumulate_n( + filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + + const auto desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(NDoHoWo, CZYX), make_tuple(gemm_g_m_k_strides[I1], gemm_g_m_k_strides[I2])); + return matrix_padder.PadADescriptor_M_K(desc_mraw_kraw); + } + + using InputGridDesc = + remove_cvref_t; + using OutputGridDesc = remove_cvref_t; + + using Block2ETileMap = remove_cvref_t< + decltype(BlockToCTileMap_M00_N0_M01Adapt( + OutputGridDesc{}))>; + + using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange>; + + struct Argument : public BaseArgument + { + Argument(const void* p_in, // input image + void* p_out, // gemm form + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + : G_(G), + C_(C), + X_(filter_spatial_lengths[NDimSpatial - I1]), + p_in_{static_cast(p_in)}, + p_out_{static_cast(p_out)}, + image_g_n_c_wis_strides_{image_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + + in_grid_desc_m_k_ = MakeInputDescriptor_M_K(N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + out_grid_desc_m_k_ = MakeOutDescriptor_M_K( + N, C, filter_spatial_lengths, output_spatial_lengths, gemm_g_m_k_strides); + + compute_ptr_offset_of_batch_.BatchStrideA_ = image_g_n_c_wis_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = gemm_g_m_k_strides[I0]; + } + + void Print() const + { + std::cout << in_grid_desc_m_k_ << std::endl; + std::cout << out_grid_desc_m_k_ << std::endl; + } + + const ck::index_t G_; + const ck::index_t C_; + const ck::index_t X_; + + const InputDataType* p_in_; + OutputDataType* p_out_; + + const std::array& image_g_n_c_wis_strides_; + const std::array& conv_filter_strides_; + const std::array& conv_filter_dilations_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + + InputGridDesc in_grid_desc_m_k_; + OutputGridDesc out_grid_desc_m_k_; + + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + const auto block_2_tile_map = + BlockToCTileMap_M00_N0_M01Adapt( + arg.out_grid_desc_m_k_); + const index_t grid_size = + block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_) * arg.G_; + const auto kernel = kernel_tensor_rearrange, + GridwiseTensorRearrangeKernel>; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.in_grid_desc_m_k_, + arg.p_in_, + arg.out_grid_desc_m_k_, + arg.p_out_, + arg.G_, + block_2_tile_map, + arg.compute_ptr_offset_of_batch_); + return elapsed_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const Argument& arg) + { + if constexpr(!(is_NSpatialGC || is_GNSpatialC)) + { + return false; + } + + const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1]; + const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1]; + const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1]; + const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1]; + bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_; + bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1; + + // check vector acces with c not packed + if(!is_c_packed && ScalarPerVector != 1) + return false; + // check vector access of filter window row (only C if C is not packed) + if(!is_w_packed && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of filter window row (X * C) + if(arg.X_ * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of pads (w_pad_left/w_pad_right * C) + if(w_pad_left * arg.C_ % ScalarPerVector != 0 || + w_pad_right * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with stride and pad + if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with dilation + if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + + return GridwiseTensorRearrangeKernel::CheckValidity(arg.in_grid_desc_m_k_, + arg.out_grid_desc_m_k_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_in, // input image + void* p_out, // gemm form + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + return Argument{static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in, // input image + void* p_out, // gemm form + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) override + { + return std::make_unique(static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceImageToColumn" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << KPerBlock << ", " + << ScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp similarity index 93% rename from include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp index 175994d49f3a3121e2e32a4ca2ef07e60c0e52e5..e98a85defe95aabed5c977250048cae117b701d6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp @@ -8,7 +8,7 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp" +#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -25,7 +25,7 @@ template -struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd +struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd { using DInDataType_AutomicAddPreCast = conditional_t || is_same_v, @@ -91,7 +91,8 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd& window_lengths, - const std::vector& window_strides) + const std::vector& window_strides, + const std::vector& window_dilations) : p_dout_{p_dout}, p_indices_{p_indices}, p_din_{p_din}, @@ -102,7 +103,8 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd window_strides.at(i); + auto eff = (window_lengths.at(i) - 1) * window_dilations.at(i) + 1; + windowOverlap_ |= eff > window_strides.at(i); } } @@ -228,6 +230,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd window_lengths, - std::vector window_strides) override + std::vector window_strides, + std::vector window_dilations) override { // Assume p_dout, p_indices, p_din are packed memory space, dout_length and din_length are // physical size of the packed tensor @@ -302,7 +310,8 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd MakeInvokerPointer() override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..86689af0b7021a8946717b5e7a1ac23559783aa2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp @@ -0,0 +1,465 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +// M is Invariant dimension, K is reduced dimension +namespace ck { +namespace tensor_operation { +namespace device { +template +__global__ void +kernel_normalization_bwd_data(const GridDesc_M_K dy_grid_desc_m_k, + const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K gamma_grid_desc_m_k, + const GridDesc_M_K mean_grid_desc_m_k, + const GridDesc_M_K inv_std_grid_desc_m_k, + const GridDesc_M_K dx_grid_desc_m_k, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DXDataType* const __restrict__ p_dx_global) +{ + GridwiseNormalizationBwd::Run(dy_grid_desc_m_k, + x_grid_desc_m_k, + gamma_grid_desc_m_k, + mean_grid_desc_m_k, + inv_std_grid_desc_m_k, + dx_grid_desc_m_k, + num_k_block_tile_iteration, + p_dy_global, + p_x_global, + p_gamma_global, + p_mean_global, + p_inv_std_global, + p_dx_global); +}; + +template +struct DeviceNormalizationBwdDataImpl : public DeviceNormalizationBwdData +{ + static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0; + static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0; + static constexpr index_t GammaSrcVectorDim = IsGammaFastestDimReduced ? 1 : 0; + static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0; + static constexpr index_t DXDstVectorDim = IsDxFastestDimReduced ? 1 : 0; + + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) || + (DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || + (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0), + "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please " + "check!"); + + static_assert(((DXDstVectorDim == 0 && MThreadSliceSize % DXDstVectorSize == 0) || + (DXDstVectorDim == 1 && KThreadSliceSize % DXDstVectorSize == 0)), + "Invalid thread slice sizes and/or dx vector sizes configuration, please check!"); + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); + + static auto Make2dDescriptor(const std::vector& lengths, + const std::vector& strides, + int numBlockTileIteration) + { + const auto tupleLengths = make_tuple_from_array(lengths, Number{}); + const auto tupleStrides = make_tuple_from_array(strides, Number{}); + + const auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); + + const auto grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(lengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(lengths, InvariantDims{}); + + return transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{}); + + const auto pad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto pad_K = K_BlockTileSize * numBlockTileIteration - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, pad_M), + make_right_pad_transform(reduceLength, pad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k_padded; + } + + using GridDesc_M_K = decltype(Make2dDescriptor({1}, {1}, 1)); + + using GridwiseNormalizationBwdDataGeneric = + GridwiseNormalizationBwdData_mk_to_mk; + + using GridwiseNormalizationBwdDataSweepOnce = + GridwiseNormalizationBwdData_mk_to_mk; + + struct Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const DYDataType* p_dy, + const XDataType* p_x, + const GammaDataType* p_gamma, + const MeanInvStdDataType* p_mean, + const MeanInvStdDataType* p_invStd, + DXDataType* p_dx) + : p_dy_(p_dy), + p_x_(p_x), + p_gamma_(p_gamma), + p_mean_(p_mean), + p_invStd_(p_invStd), + p_dx_(p_dx) + { + lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + dyStrides_ = shuffle_tensor_dimensions(dyStrides, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); + meanStrides_ = shuffle_tensor_dimensions(meanStrides, reduceDims); + invStdStrides_ = + shuffle_tensor_dimensions(invStdStrides, reduceDims); + dxStrides_ = shuffle_tensor_dimensions(dxStrides, reduceDims); + + std::tie(MRaw_, KRaw_) = get_2d_lengths(lengths_); + + numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize); + + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize); + + dy_grid_desc_m_k_ = Make2dDescriptor(lengths_, dyStrides_, numBlockTileIteration_); + x_grid_desc_m_k_ = Make2dDescriptor(lengths_, xStrides_, numBlockTileIteration_); + gamma_grid_desc_m_k_ = + Make2dDescriptor(lengths_, gammaStrides_, numBlockTileIteration_); + mean_grid_desc_m_k_ = Make2dDescriptor(lengths_, meanStrides_, numBlockTileIteration_); + inv_std_grid_desc_m_k_ = + Make2dDescriptor(lengths_, invStdStrides_, numBlockTileIteration_); + dx_grid_desc_m_k_ = Make2dDescriptor(lengths_, dxStrides_, numBlockTileIteration_); + + isSweeponce_ = dy_grid_desc_m_k_.GetLength(Number<1>{}) <= K_BlockTileSize; + } + + const DYDataType* p_dy_; + const XDataType* p_x_; + const GammaDataType* p_gamma_; + const MeanInvStdDataType* p_mean_; + const MeanInvStdDataType* p_invStd_; + DXDataType* p_dx_; + + std::vector lengths_; + std::vector dyStrides_; + std::vector xStrides_; + std::vector gammaStrides_; + std::vector meanStrides_; + std::vector invStdStrides_; + std::vector dxStrides_; + + int numBlockTileIteration_; + size_t gridSize_; + + // tensor descriptor + GridDesc_M_K dy_grid_desc_m_k_; + GridDesc_M_K x_grid_desc_m_k_; + GridDesc_M_K gamma_grid_desc_m_k_; + GridDesc_M_K mean_grid_desc_m_k_; + GridDesc_M_K inv_std_grid_desc_m_k_; + GridDesc_M_K dx_grid_desc_m_k_; + + bool isSweeponce_; + index_t MRaw_; // Invariant length + index_t KRaw_; // reduce length + }; + + struct Invoker : public BaseInvoker + { + auto KernelSelector(bool isSweepOnce) + { + return isSweepOnce + ? kernel_normalization_bwd_data + : kernel_normalization_bwd_data; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel_main = KernelSelector(arg.isSweeponce_); + + return launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.dy_grid_desc_m_k_, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.mean_grid_desc_m_k_, + arg.inv_std_grid_desc_m_k_, + arg.dx_grid_desc_m_k_, + arg.numBlockTileIteration_, + arg.p_dy_, + arg.p_x_, + arg.p_gamma_, + arg.p_mean_, + arg.p_invStd_, + arg.p_dx_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + template + bool IsVectorDimSizeValid(const std::vector& lengths, + const std::vector& strides) + { + if constexpr(SrcVectorSize == 1) + return true; + + // Fastest dimension is not reduced + if constexpr(SrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + return false; + + if(strides[NumInvariantDim - 1] != 1) + return false; + + if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0) + return false; + } + else // Fastest dimension is reduced + { + if(strides[Rank - 1] != 1) + return false; + + if(lengths[Rank - 1] % SrcVectorSize != 0) + return false; + }; + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + bool pass = true; + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->dyStrides_); + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->xStrides_); + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->gammaStrides_); + pass &= IsVectorDimSizeValid( + p_arg_->lengths_, p_arg_->meanStrides_); + pass &= IsVectorDimSizeValid( + p_arg_->lengths_, p_arg_->invStdStrides_); + + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->dxStrides_); + return pass; + } + + std::unique_ptr MakeArgumentPointer(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_gamma, + const void* p_mean, + const void* p_invStd, + void* p_dx) override + { + if(lengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank || + gammaStrides.size() != Rank || meanStrides.size() != Rank || + invStdStrides.size() != Rank || dxStrides.size() != Rank) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(lengths, + dyStrides, + xStrides, + gammaStrides, + meanStrides, + invStdStrides, + dxStrides, + reduceDims, + static_cast(p_dy), + static_cast(p_x), + static_cast(p_gamma), + static_cast(p_mean), + static_cast(p_invStd), + static_cast(p_dx)); + } + + virtual std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationBwdDataImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "DYSrcVectorSize" << DYSrcVectorSize << "_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_MeanRstd" << MeanInvStdSrcVectorSize << "_Dx" << DXDstVectorSize; + str << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c35652bfe860aeb56f4d5621aed2a312c78282d6 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp @@ -0,0 +1,478 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +// M is Invariant dimension, K is reduced dimension +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +kernel_normalization_bwd_gamma_beta(const GridDesc_M_K dy_grid_desc_m_k, + const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K mean_grid_desc_m_k, + const GridDesc_M_K inv_std_grid_desc_m_k, + const GridDesc_M dgamma_grid_desc_m, + const GridDesc_M dbeta_grid_desc_m, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DGammaDataType* const __restrict__ p_dgamma_global, + DBetaDataType* const __restrict__ p_dbeta_global) +{ + GridwiseReduction::Run(dy_grid_desc_m_k, + x_grid_desc_m_k, + mean_grid_desc_m_k, + inv_std_grid_desc_m_k, + dgamma_grid_desc_m, + dbeta_grid_desc_m, + num_k_block_tile_iteration, + p_dy_global, + p_x_global, + p_mean_global, + p_inv_std_global, + p_dgamma_global, + p_dbeta_global); +}; + +template +struct DeviceNormalizationBwdGammaBetaImpl + : public DeviceNormalizationBwdGammaBeta +{ + static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0; + static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0; + static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0; + + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) || + (DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + static_assert( + (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0), + "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please " + "check!"); + + static_assert( + ((MThreadSliceSize % DGammaDstVectorSize == 0) || + (MThreadSliceSize % DBetaDstVectorSize == 0)), + "Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please " + "check!"); + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int numBlockTileIteration) + { + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor(inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_grid_desc_m_k_padded; + } + + static auto MakeDst1dDescriptor(const std::vector& outLengths, + const std::vector& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1)); + using GridDesc_M = decltype(MakeDst1dDescriptor({1}, {1})); + + using GridwiseNormalizationBwdGammaBeta = + GridwiseNormalizationBwdGammaBeta_mk_to_k; + + struct Argument : public BaseArgument + { + Argument(const std::vector inLengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector outLengths, + const std::vector dgammaStrides, + const std::vector dbetaStrides, + const std::vector reduceDims, + const DYDataType* p_dy, + const XDataType* p_x, + const MeanInvStdDataType* p_mean, + const MeanInvStdDataType* p_invStd, + DGammaDataType* p_dgamma, + DBetaDataType* p_dbeta) + : p_dy_(p_dy), + p_x_(p_x), + p_mean_(p_mean), + p_invStd_(p_invStd), + p_dgamma_(p_dgamma), + p_dbeta_(p_dbeta), + outLengths_{outLengths}, + dgammaStrides_{dgammaStrides}, + dbetaStrides_{dbetaStrides} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + dyStrides_ = shuffle_tensor_dimensions(dyStrides, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + meanStrides_ = shuffle_tensor_dimensions(meanStrides, reduceDims); + invStdStrides_ = + shuffle_tensor_dimensions(invStdStrides, reduceDims); + + std::tie(MRaw_, KRaw_) = get_2d_lengths(inLengths_); + + numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize); + + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize); + + dy_grid_desc_m_k_ = MakeSrc2dDescriptor(inLengths_, dyStrides_, numBlockTileIteration_); + x_grid_desc_m_k_ = MakeSrc2dDescriptor(inLengths_, xStrides_, numBlockTileIteration_); + mean_grid_desc_m_k_ = + MakeSrc2dDescriptor(inLengths_, meanStrides_, numBlockTileIteration_); + inv_std_grid_desc_m_k_ = + MakeSrc2dDescriptor(inLengths_, invStdStrides_, numBlockTileIteration_); + + dgamma_grid_desc_m_ = MakeDst1dDescriptor(outLengths_, dgammaStrides_); + dbeta_grid_desc_m_ = MakeDst1dDescriptor(outLengths_, dbetaStrides_); + } + + const DYDataType* p_dy_; + const XDataType* p_x_; + const MeanInvStdDataType* p_mean_; + const MeanInvStdDataType* p_invStd_; + DGammaDataType* p_dgamma_; + DBetaDataType* p_dbeta_; + + std::vector inLengths_; + std::vector dyStrides_; + std::vector xStrides_; + std::vector meanStrides_; + std::vector invStdStrides_; + std::vector outLengths_; + std::vector dgammaStrides_; + std::vector dbetaStrides_; + + int numBlockTileIteration_; + size_t gridSize_; + + // Source descriptor + GridDesc_M_K dy_grid_desc_m_k_; + GridDesc_M_K x_grid_desc_m_k_; + GridDesc_M_K mean_grid_desc_m_k_; + GridDesc_M_K inv_std_grid_desc_m_k_; + + // Destination descriptor + GridDesc_M dgamma_grid_desc_m_; + GridDesc_M dbeta_grid_desc_m_; + + index_t MRaw_; // Invariant length + index_t KRaw_; // reduce length + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel_main = + kernel_normalization_bwd_gamma_beta; + + return launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.dy_grid_desc_m_k_, + arg.x_grid_desc_m_k_, + arg.mean_grid_desc_m_k_, + arg.inv_std_grid_desc_m_k_, + arg.dgamma_grid_desc_m_, + arg.dbeta_grid_desc_m_, + arg.numBlockTileIteration_, + arg.p_dy_, + arg.p_x_, + arg.p_mean_, + arg.p_invStd_, + arg.p_dgamma_, + arg.p_dbeta_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + template + bool IsSrcVectorDimSizeValid(const std::vector& lengths, + const std::vector& strides) + { + if constexpr(SrcVectorSize == 1) + return true; + + // Fastest dimension is not reduced + if constexpr(SrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + return false; + + if(strides[NumInvariantDim - 1] != 1) + return false; + + if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0) + return false; + } + else // Fastest dimension is reduced + { + if(strides[Rank - 1] != 1) + return false; + + if(lengths[Rank - 1] % SrcVectorSize != 0) + return false; + }; + + return true; + } + + template + bool IsDstVectorSizeValid(const std::vector& lengths, + const std::vector& strides) + { + if constexpr(DstVectorSize == 1) + return true; + + if(strides[NumInvariantDim - 1] != 1) + return false; + + if(lengths[NumInvariantDim - 1] % DstVectorSize != 0) + return false; + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + bool pass = true; + pass &= IsSrcVectorDimSizeValid(p_arg_->inLengths_, + p_arg_->dyStrides_); + pass &= IsSrcVectorDimSizeValid(p_arg_->inLengths_, + p_arg_->xStrides_); + pass &= IsSrcVectorDimSizeValid( + p_arg_->inLengths_, p_arg_->meanStrides_); + pass &= IsSrcVectorDimSizeValid( + p_arg_->inLengths_, p_arg_->invStdStrides_); + + pass &= + IsDstVectorSizeValid(p_arg_->outLengths_, p_arg_->dgammaStrides_); + pass &= + IsDstVectorSizeValid(p_arg_->outLengths_, p_arg_->dbetaStrides_); + + return pass; + } + + std::unique_ptr MakeArgumentPointer(const std::vector inLengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector outLengths, + const std::vector dgammaStrides, + const std::vector dbetaStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_mean, + const void* p_invStd, + void* p_dgamma, + void* p_dbeta) override + { + if(inLengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank || + meanStrides.size() != Rank || invStdStrides.size() != Rank) + throw std::runtime_error("dimension is incorrect"); + + if(outLengths.size() != NumInvariantDim || dgammaStrides.size() != NumInvariantDim || + dbetaStrides.size() != NumInvariantDim) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(inLengths, + dyStrides, + xStrides, + meanStrides, + invStdStrides, + outLengths, + dgammaStrides, + dbetaStrides, + reduceDims, + static_cast(p_dy), + static_cast(p_x), + static_cast(p_mean), + static_cast(p_invStd), + static_cast(p_dgamma), + static_cast(p_dbeta)); + } + + virtual std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationBwdGammaBetaImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "VectorSize_DY" << DYSrcVectorSize << "_X" << XSrcVectorSize ; + str << "_DGamma" << DGammaDstVectorSize << "_DBeta" << DBetaDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp similarity index 72% rename from include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp index ea0d805043e07fe3d818fbb7bf7ef42e952496a9..fa7b9cbda02334449d2af14561262cdf51f69356 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp @@ -7,7 +7,7 @@ #include #include "ck/utility/reduction_operator.hpp" -#include "ck/tensor_operation/gpu/device/device_normalization.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp" @@ -19,7 +19,7 @@ namespace tensor_operation { namespace device { // Y = Normalization(X, Beta, Gamma) -// M: Invarient length +// M: Invariant length // K: Reduce length (Calculate mean and variance along K dimension) // eg. Length = [N, C, H, W], reduce dim = [C, H, W] // Then, M = N, K = C * H * W @@ -28,6 +28,7 @@ template -struct DeviceNormalizationImpl : public DeviceNormalization +struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd { static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); static_assert( @@ -64,18 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization& inLengths, const std::vector& inStrides, int numBlockTileIteration) { - constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t numSrcDim = Rank; - static constexpr bool reduceAllDim = (NumInvariantDim == 0); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); @@ -133,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization& lengths, + const std::vector& strides) + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + + const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{}); + const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{}); + + const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto grid_desc_m = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(InvariantDims{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + const auto pad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = transform_tensor_descriptor( + grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, pad_M)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return grid_desc_m_padded; + } + using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1)); + using GridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1})); struct Argument : public BaseArgument { @@ -142,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization gammaStrides, const std::vector betaStrides, const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, const std::vector reduceDims, YElementwiseOperation y_elementwise_op, double epsilon, const XDataType* p_x, const GammaDataType* p_gamma, const BetaDataType* p_beta, - YDataType* p_y) + YDataType* p_y, + SaveMeanInvStdDataType* p_saveMean, + SaveMeanInvStdDataType* p_saveInvStd) : p_x_(p_x), p_gamma_(p_gamma), p_beta_(p_beta), p_y_(p_y), + p_saveMean_(p_saveMean), + p_saveInvStd_(p_saveInvStd), y_elementwise_op_(y_elementwise_op) { epsilon_ = static_cast(epsilon); @@ -162,16 +206,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization(yStrides, reduceDims); gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); + saveMeanStrides_ = saveMeanStrides; + saveInvStdStrides_ = saveInvStdStrides; - long_index_t invariant_length; - long_index_t reduce_length; - - std::tie(invariant_length, reduce_length) = - get_2d_lengths(Lengths_); + std::tie(MRaw_, KRaw_) = get_2d_lengths(Lengths_); - numBlockTileIteration_ = math::integer_divide_ceil(reduce_length, K_BlockTileSize); + numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize); - gridSize_ = math::integer_divide_ceil(invariant_length, M_BlockTileSize); + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize); x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_); gamma_grid_desc_m_k_ = @@ -179,9 +221,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization{}) <= KThreadClusterSize * KThreadSliceSize; + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length_ = 1; + else + invariant_lowest_length_ = Lengths_[NumInvariantDim - 1]; } ComputeDataType epsilon_; @@ -190,12 +239,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization Lengths_; std::vector xStrides_; std::vector gammaStrides_; std::vector betaStrides_; std::vector yStrides_; + std::vector saveMeanStrides_; + std::vector saveInvStdStrides_; YElementwiseOperation y_elementwise_op_; @@ -206,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization(arg.isSweeponce_); float avg_time = 0; @@ -245,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization(p_arg); - constexpr index_t NumInvariantDim = Rank - NumReduceDim; - if constexpr(XYSrcVectorDim == 0) { if constexpr(NumInvariantDim == 0) @@ -280,10 +345,10 @@ struct DeviceNormalizationImpl : public DeviceNormalizationxStrides_[NumInvariantDim - 1] != 1) return false; - if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) + if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0) return false; - if(p_arg_->invariant_lowest_length % YDstVectorSize != 0) + if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0) return false; }; } @@ -325,7 +390,7 @@ struct DeviceNormalizationImpl : public DeviceNormalizationbetaStrides_[NumInvariantDim - 1] != 1) return (false); - if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) + if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0) return (false); } else // if fastest dim is reduced @@ -337,6 +402,9 @@ struct DeviceNormalizationImpl : public DeviceNormalizationinvariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0) + return false; + return true; }; @@ -346,6 +414,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization gammaStrides, const std::vector betaStrides, const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, const std::vector reduceDims, double epsilon, const void* p_x, @@ -353,27 +423,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization(lengths, xStrides, gammaStrides, betaStrides, yStrides, + saveMeanStrides, + saveInvStdStrides, reduceDims, y_elementwise_op, epsilon, static_cast(p_x), static_cast(p_gamma), static_cast(p_beta), - static_cast(p_y)); + static_cast(p_y), + static_cast(p_saveMean), + static_cast(p_saveInvStd)); }; std::unique_ptr MakeInvokerPointer() override @@ -386,7 +459,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization @@ -28,8 +28,8 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k, const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, index_t num_k_block_tile_iteration, const XDataType* const __restrict__ p_x_global, - MeanVarDataType* const __restrict__ p_welford_mean, - MeanVarDataType* const __restrict__ p_welford_variance, + WorkspaceMeanVarDataType* const __restrict__ p_welford_mean, + WorkspaceMeanVarDataType* const __restrict__ p_welford_variance, int32_t* const __restrict__ p_welford_count) { GridwiseWelford::Run(x_grid_desc_m_k, @@ -42,16 +42,18 @@ kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k, }; template + typename XYGammaBetaGridDesc_M_K, + typename SaveMeanInvStdGridDesc_M> __global__ void kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, const CountGridDesc_M_KBlock count_grid_desc_m_kblock, @@ -59,17 +61,21 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_ const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k, const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k, const XYGammaBetaGridDesc_M_K y_grid_desc_m_k, + const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m, + const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m, index_t num_k_mean_var_count_iteration, index_t num_k_block_tile_iteration, index_t k_grid_size, ComputeDataType epsilon, - const MeanVarDataType* const p_mean_global, - const MeanVarDataType* const p_variance_global, + const WorkspaceMeanVarDataType* const p_mean_global, + const WorkspaceMeanVarDataType* const p_variance_global, const int32_t* const p_welford_count_global, const XDataType* const __restrict__ p_x_global, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op) { GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock, @@ -78,6 +84,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_ gamma_grid_desc_m_k, beta_grid_desc_m_k, y_grid_desc_m_k, + save_mean_grid_desc_m, + save_inv_std_grid_desc_m, num_k_mean_var_count_iteration, num_k_block_tile_iteration, k_grid_size, @@ -89,6 +97,8 @@ kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_ p_gamma_global, p_beta_global, p_y_global, + p_save_mean_global, + p_save_inv_std_global, y_elementwise_op); }; } // namespace ck @@ -98,7 +108,7 @@ namespace tensor_operation { namespace device { // Y = Normalization(X, Beta, Gamma) -// M: Invarient length +// M: Invariant length // K: Reduce length (Calculate mean and variance along K dimension) // eg. Length = [N, C, H, W], reduce dim = [C, H, W] // Then, M = N, K = C * H * W @@ -107,6 +117,7 @@ template -struct DeviceNormalizationSplitKImpl : public DeviceNormalization + index_t YDstVectorSize, + index_t SaveMeanInvStdDstVectorSize> +struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd { - using MeanVarDataType = ComputeDataType; + using WorkspaceMeanVarDataType = SaveMeanInvStdDataType; static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); static_assert( @@ -144,22 +156,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization{}; static constexpr auto I1 = Number<1>{}; + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); // TODO + static auto MakeSrc2dDescriptor(const std::vector& inLengths, const std::vector& inStrides, int kBlockSize, int numBlockTileIteration) { - constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t numSrcDim = Rank; - static constexpr bool reduceAllDim = (NumInvariantDim == 0); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); @@ -219,7 +237,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization - static auto MakeMeanVarDescriptor_M_K(index_t M, index_t K) + static auto MakeWorkspaceMeanVarDescriptor_M_K(index_t M, index_t K) { const auto grid_desc_m_k = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); @@ -227,26 +245,57 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization - static auto MakeCountDescriptor_M_K(index_t M, index_t K) + static auto MakeWorkspaceCountDescriptor_M_K(index_t M, index_t K) { const auto grid_desc_m_k = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1)); return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{}); } + static auto MakeSaveMeanInvStdDescriptor_M(const std::vector& lengths, + const std::vector& strides) + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + + const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{}); + const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{}); + + const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto grid_desc_m = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(InvariantDims{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + const auto pad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = transform_tensor_descriptor( + grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, pad_M)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return grid_desc_m_padded; + } + using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); using Kernel1MeanVarGridDesc_M_KBlock = - decltype(MakeMeanVarDescriptor_M_K, 1, 1>(1, 1)); + decltype(MakeWorkspaceMeanVarDescriptor_M_K, 1, 1>(1, 1)); using Kernel2MeanVarGridDesc_M_KBlock = - decltype(MakeMeanVarDescriptor_M_K, 1, 1>(1, 1)); + decltype(MakeWorkspaceMeanVarDescriptor_M_K, 1, 1>(1, 1)); using Kernel2CountGridDesc_M_KBlock = - decltype(MakeCountDescriptor_M_K, 1, 1>(1, 1)); + decltype(MakeWorkspaceCountDescriptor_M_K, 1, 1>(1, 1)); + + using SaveMeanInvStdGridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1})); using GridwiseWelford = GridwiseNormalizationSplitK1st; using GridwiseWelfordNormalization = - GridwiseNormalizationSplitK2nd; + YDstVectorSize, + SaveMeanInvStdDstVectorSize>; struct Argument : public BaseArgument { @@ -289,17 +341,23 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization gammaStrides, const std::vector betaStrides, const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, const std::vector reduceDims, YElementwiseOperation y_elementwise_op, double epsilon, const XDataType* p_x, const GammaDataType* p_gamma, const BetaDataType* p_beta, - YDataType* p_y) + YDataType* p_y, + SaveMeanInvStdDataType* p_saveMean, + SaveMeanInvStdDataType* p_saveInvStd) : p_x_(p_x), p_gamma_(p_gamma), p_beta_(p_beta), p_y_(p_y), + p_saveMean_(p_saveMean), + p_saveInvStd_(p_saveInvStd), p_workspace_mean_{nullptr}, p_workspace_var_{nullptr}, p_workspace_count_{nullptr}, @@ -312,6 +370,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization(yStrides, reduceDims); gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); + saveMeanStrides_ = saveMeanStrides; + saveInvStdStrides_ = saveInvStdStrides; std::tie(MRaw_, KRaw_) = get_2d_lengths(Lengths_); @@ -346,20 +406,28 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization, M_BlockTileSize, 1>(MRaw_, - kGridSize_); + MakeWorkspaceMeanVarDescriptor_M_K, M_BlockTileSize, 1>( + MRaw_, kGridSize_); kernel2_mean_var_grid_desc_m_kblock_ = - MakeMeanVarDescriptor_M_K, - M_BlockTileSize, - K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); + MakeWorkspaceMeanVarDescriptor_M_K, + M_BlockTileSize, + K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); kernel2_count_grid_desc_m_kblock_ = - MakeCountDescriptor_M_K, - M_BlockTileSize, - K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); + MakeWorkspaceCountDescriptor_M_K, + M_BlockTileSize, + K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length_ = 1; + else + invariant_lowest_length_ = Lengths_[NumInvariantDim - 1]; } ComputeDataType epsilon_; @@ -368,6 +436,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization gammaStrides_; std::vector betaStrides_; std::vector yStrides_; + std::vector saveMeanStrides_; + std::vector saveInvStdStrides_; YElementwiseOperation y_elementwise_op_; @@ -389,13 +461,17 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization; auto kernel2 = kernel_normalizationSplitK2nd; + SrcGridDesc_M_K, + SaveMeanInvStdGridDesc_M>; float avg_time = 0; - avg_time += launch_and_time_kernel(stream_config, - kernel1, - dim3(arg.gridSize_), - dim3(BlockSize), - 0, - arg.x_grid_desc_m_k_, - arg.kernel1_mean_var_grid_desc_m_kblock_, - arg.numBlockTileIteration_, - arg.p_x_, - static_cast(arg.p_workspace_mean_), - static_cast(arg.p_workspace_var_), - static_cast(arg.p_workspace_count_)); - - avg_time += launch_and_time_kernel(stream_config, - kernel2, - dim3(arg.gridSize_), - dim3(BlockSize), - 0, - arg.kernel2_mean_var_grid_desc_m_kblock_, - arg.kernel2_count_grid_desc_m_kblock_, - arg.x_grid_desc_m_k_, - arg.gamma_grid_desc_m_k_, - arg.beta_grid_desc_m_k_, - arg.y_grid_desc_m_k_, - arg.numMeanVarCountIteration_, - arg.numBlockTileIteration_, - arg.kGridSize_, - arg.epsilon_, - static_cast(arg.p_workspace_mean_), - static_cast(arg.p_workspace_var_), - static_cast(arg.p_workspace_count_), - arg.p_x_, - arg.p_gamma_, - arg.p_beta_, - arg.p_y_, - arg.y_elementwise_op_); + avg_time += launch_and_time_kernel( + stream_config, + kernel1, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.kernel1_mean_var_grid_desc_m_kblock_, + arg.numBlockTileIteration_, + arg.p_x_, + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_)); + + avg_time += launch_and_time_kernel( + stream_config, + kernel2, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.kernel2_mean_var_grid_desc_m_kblock_, + arg.kernel2_count_grid_desc_m_kblock_, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.beta_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.save_mean_grid_desc_m_, + arg.save_inv_std_grid_desc_m_, + arg.numMeanVarCountIteration_, + arg.numBlockTileIteration_, + arg.kGridSize_, + arg.epsilon_, + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_), + arg.p_x_, + arg.p_gamma_, + arg.p_beta_, + arg.p_y_, + arg.p_saveMean_, + arg.p_saveInvStd_, + arg.y_elementwise_op_); return avg_time; }; @@ -482,10 +566,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalizationMRaw_ * pArg_->kGridSize_; // workspace for welford intermediate mean - workspace_size += welford_size * sizeof(MeanVarDataType) + 64; + workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64; // workspace for welford intermediate variance - workspace_size += welford_size * sizeof(MeanVarDataType) + 64; + workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64; // workspace for welford intermediate count workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64; @@ -493,7 +577,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization(pArg); @@ -504,13 +590,13 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalizationp_workspace_mean_ = static_cast(pArg_->p_workspace_); - index_t mean_space_sz = welford_size * sizeof(MeanVarDataType); + index_t mean_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType); mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); // setup buffer used for intermediate welford varirance pArg_->p_workspace_var_ = reinterpret_cast(pArg_->p_workspace_mean_) + mean_space_sz; - index_t variance_space_sz = welford_size * sizeof(MeanVarDataType); + index_t variance_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType); variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); // setup buffer used for intermediate welford count @@ -522,8 +608,6 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization(p_arg); - constexpr index_t NumInvariantDim = Rank - NumReduceDim; - if constexpr(XYVectorDim == 0) { if constexpr(NumInvariantDim == 0) @@ -535,10 +619,10 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalizationxStrides_[NumInvariantDim - 1] != 1) return false; - if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) + if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0) return false; - if(p_arg_->invariant_lowest_length % YDstVectorSize != 0) + if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0) return false; }; } @@ -578,7 +662,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalizationbetaStrides_[NumInvariantDim - 1] != 1) return false; - if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) + if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0) return false; } else // if fastest dim is reduced @@ -593,6 +677,9 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalizationkGridSize_ <= 1) return false; + if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0) + return false; + return true; }; @@ -602,6 +689,8 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization gammaStrides, const std::vector betaStrides, const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, const std::vector reduceDims, double epsilon, const void* p_x, @@ -609,27 +698,30 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization(lengths, xStrides, gammaStrides, betaStrides, yStrides, + saveMeanStrides, + saveInvStdStrides, reduceDims, y_elementwise_op, epsilon, static_cast(p_x), static_cast(p_gamma), static_cast(p_beta), - static_cast(p_y)); + static_cast(p_y), + static_cast(p_saveMean), + static_cast(p_saveInvStd)); }; std::unique_ptr MakeInvokerPointer() override @@ -642,7 +734,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization(x1); }; + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const float& x1) const + { + y = type_convert(x0 + x1); + }; + template <> __host__ __device__ constexpr void operator()(half_t& y, const float& x0, const half_t& x1) const @@ -68,6 +75,15 @@ struct Add y = ck::type_convert(y_tmp); } + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x0 + x2_tmp; + y = ck::type_convert(y_tmp); + } + template <> __host__ __device__ constexpr void operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const @@ -78,10 +94,13 @@ struct Add struct ScaleAdd { - __host__ __device__ ScaleAdd(float scale) : scale_(scale) {} + __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {} template - __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const + { + y = ck::type_convert(scale_ * ck::type_convert(x0) + ck::type_convert(x1)); + } template <> __host__ __device__ void @@ -146,7 +165,7 @@ struct Subtract struct Bilinear { - Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; template __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const; @@ -165,6 +184,14 @@ struct Bilinear y = alpha_ * x0 + beta_ * x1; }; + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = type_convert(alpha_ * type_convert(x0) + + beta_ * type_convert(x1)); + }; + template <> __host__ __device__ constexpr void operator()(half_t& y, const half_t& x0, const half_t& x1) const @@ -179,6 +206,33 @@ struct Bilinear y = type_convert(alpha_ * x0 + beta_ * ck::type_convert(x1)); }; + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x0_tmp = type_convert(x0); + const float x1_tmp = type_convert(x1); + const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp; + y = type_convert(y_tmp); + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x1); + const float y_tmp = alpha_ * x0 + beta_ * x1_tmp; + y = y_tmp; + }; + + template <> + __host__ __device__ constexpr void operator()( + std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const + { + y = type_convert(alpha_ * type_convert(x0) + + beta_ * type_convert(x1)); + }; + float alpha_; float beta_; }; @@ -228,6 +282,14 @@ struct AddRelu y = a > 0.0f ? a : 0.0f; }; + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float a = x0 + type_convert(x1); + y = a > type_convert(0.0f) ? a : type_convert(0.0f); + }; + template <> __host__ __device__ constexpr void operator()(int& y, const int& x0, const int8_t& x1) const @@ -318,6 +380,70 @@ struct AddFastGelu e = type_convert(x1_f); } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const float& c, const bhalf_t& d) const + { + const float x0_f = c + type_convert(d); + + float x1_f = 0; + + FastGelu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } +}; + +// E = Silu(C + D) +struct AddSilu +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d) const + { + const float x = c + d; + + Silu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const half_t& c, const half_t& d) const + { + const half_t x = c + d; + + Silu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const float& c, const half_t& d) const + { + const float x0_f = c + d; + + float x1_f = 0; + + Silu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const float& c, const bhalf_t& d) const + { + const float x0_f = c + type_convert(d); + + float x1_f = 0; + + Silu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } }; } // namespace element_wise diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 3fdb391a0215e83328905b77148b675defc27af7..7fdd6448b6ca11187c6d873f2838f55fb24cde64 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -195,6 +195,51 @@ struct AddMultiply } }; +// C = A * B +// E = C x D0 + D1 +struct MultiplyAdd +{ + template + __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ void operator()(half_t& e, + const half_t& c, + const half_t& d0, + const half_t& d1) const + { + const half_t y = (c * d0) + d1; + e = y; + } + template <> + __host__ __device__ void operator()(half_t& e, + const float& c, + const half_t& d0, + const half_t& d1) const + { + const half_t y = type_convert(c) * d0 + d1; + e = y; + } + template <> + __host__ __device__ void operator()(float& e, + const float& c, + const half_t& d0, + const half_t& d1) const + { + const float y = c * d0 + d1; + e = y; + } + template <> + __host__ __device__ void operator()(half_t& e, + const float& c, + const float& d0, + const float& d1) const + { + const float y = c * d0 + d1; + e = y; + } +}; + // E = FastGelu(C + D0 + D1) struct AddAddFastGelu { @@ -266,6 +311,71 @@ struct AddAddFastGelu } }; +// E = Relu(alpha1 * C + alpha2 * D0 + D1) +struct ScaleAddScaleAddRelu +{ + + ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f) + : alpha1_(alpha1), alpha2_(alpha2) + { + } + + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(float& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x = c * alpha1_ + alpha2_ * d0 + d1; + Relu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void operator()( + half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const + { + const float x = type_convert(c) * alpha1_ + alpha2_ * type_convert(d0) + + type_convert(d1); + + float result = 0; + Relu{}.template operator()(result, x); + + e = type_convert(result); + } + + template <> + __host__ __device__ constexpr void operator()( + bhalf_t& e, const bhalf_t& c, const bhalf_t& d0, const bhalf_t& d1) const + { + const float x = type_convert(c) * alpha1_ + alpha2_ * type_convert(d0) + + type_convert(d1); + + float result = 0; + Relu{}.template operator()(result, x); + + e = type_convert(result); + } + + template <> + __host__ __device__ constexpr void operator()( + int8_t& e, const int8_t& c, const float& d0, const float& d1) const + { + const float x = type_convert(c) * alpha1_ + alpha2_ * d0 + d1; + + float result = 0; + Relu{}.template operator()(result, x); + + e = type_convert(result); + } + + const float alpha1_; + const float alpha2_; +}; + struct Normalize { // FIXME: is double absolutely necessary? diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index c47822499c870a3a1cd33a3bc11adf4fcf5ef65f..dc30ceb13d65f954374e29425dfa4c4b3435692f 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,6 +16,18 @@ namespace element_wise { extern "C" __device__ float __ocml_native_recip_f32(float); #endif +struct PassThroughPack2 +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const + { + auto t = type_convert(x); + y = type_convert(t); + } +}; + struct PassThrough { template @@ -27,6 +39,18 @@ struct PassThrough y = x; } + template <> + __host__ __device__ void operator()(float& y, const double& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(double& y, const float& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(float& y, const float& x) const { @@ -39,6 +63,12 @@ struct PassThrough y = x; } + template <> + __host__ __device__ void operator()(half_t& y, const float& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { @@ -57,30 +87,65 @@ struct PassThrough y = type_convert(x); } + template <> + __host__ __device__ void operator()(float& y, const bhalf_t& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(bhalf_t& y, const half_t& x) const { y = type_convert(x); } + template <> + __host__ __device__ void operator()(float& y, const half_t& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; } + template <> + __host__ __device__ void operator()(half_t& y, const int8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bhalf_t& y, const int8_t& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(int8_t& y, const int32_t& x) const { y = type_convert(x); } + template <> + __host__ __device__ void operator()(int8_t& y, const float& x) const + { + y = type_convert(x); + } + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> __host__ __device__ void operator()(int4_t& y, const int4_t& x) const { y = x; } + template <> + __host__ __device__ void operator()(int4_t& y, const int& x) const + { + y = type_convert(x); + } #endif template <> @@ -112,6 +177,36 @@ struct PassThrough { y = type_convert(x); } + + template <> + __host__ __device__ void operator()(bf8_t& y, const bf8_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(float& y, const bf8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bf8_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(half_t& y, const bf8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bf8_t& y, const half_t& x) const + { + y = ck::type_convert(x); + } }; struct UnaryConvert @@ -147,7 +242,8 @@ struct ConvertF8SR __host__ __device__ void operator()(Y& y, const X& x) const { // check Y datatype - static_assert(is_same::value, "Data type is not supported by this operation!"); + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); // check X datatype static_assert(is_same::value || is_same::value, @@ -157,6 +253,24 @@ struct ConvertF8SR } }; +struct ConvertF8RNE +{ + // convert to fp8 using rounding to nearest even + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + // check X datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = f8_convert_rne(x); + } +}; + struct Scale { __host__ __device__ Scale(float scale) : scale_(scale) {} @@ -164,6 +278,20 @@ struct Scale template __host__ __device__ void operator()(Y& y, const X& x) const; + template <> + __host__ __device__ void operator()(half_t& y, const half_t& x) const + { + y = ck::type_convert(scale_) * x; + }; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + const float x_tmp = ck::type_convert(x); + const float y_tmp = scale_ * x_tmp; + y = ck::type_convert(y_tmp); + }; + template <> __host__ __device__ void operator()(float& y, const float& x) const { @@ -217,8 +345,8 @@ struct UnarySquare template __host__ __device__ void operator()(T& y, const T& x) const { - static_assert(is_same_v || is_same_v || is_same_v || - is_same_v + static_assert(is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 || is_same_v #endif @@ -292,27 +420,29 @@ struct FastGelu template <> __host__ void operator()(float& y, const float& x) const { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = exp(-u); - const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); - - y = x * cdf; + // const float u = -2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = exp(u); + y = x / (1.f + emu); } #endif // device code, use lower precision "__expf" and "rcp" template <> __device__ void operator()(float& y, const float& x) const { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = __expf(-u); + // const float u = 2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = __expf(u); #if !CK_WORKAROUND_SWDEV_383542 - const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f); + y = x * __frcp_rn(1.f + emu); #else - const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f); + y = x * __ocml_native_recip_f32(1.f + emu); #endif - - y = x * cdf; } template <> @@ -383,10 +513,24 @@ struct Sigmoid __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value || + is_same::value, "Data type is not supported by this operation!"); + constexpr T one = type_convert(1); + y = one / (one + ck::math::exp(-x)); + }; +}; - y = 1 / (ck::type_convert(1) + exp(-x)); +struct Silu +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v, + "Data type is not supported by this operation!"); + constexpr T one = type_convert(1); + y = x * (one / (one + ck::math::exp(-x))); }; }; @@ -396,7 +540,8 @@ struct TanH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value || + is_same::value, "Data type is not supported by this operation!"); y = ck::math::tanh(x); @@ -407,17 +552,116 @@ struct Swish { Swish(float beta = 1.0f) : beta_(beta) {} + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + float bx = -beta_ * type_convert(x); + y = type_convert(x / (1.f + ck::math::exp(bx))); + }; + + const float beta_; +}; + +struct SoftRelu +{ + SoftRelu(float alpha = 1.f) : alpha_(alpha){}; + template __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value || + is_same::value, "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + constexpr T one = type_convert(1); + y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; + } + const float alpha_; +}; - y = x / (ck::type_convert(1) + ck::math::exp(-beta_ * x)); - }; +struct Power +{ + Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) + : alpha_(alpha), beta_(beta), gamma_(gamma){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + T casted_beta = type_convert(beta_); + T casted_gamma = type_convert(gamma_); + T shifted_scaled_x = casted_alpha + casted_beta * x; + y = ck::math::pow(shifted_scaled_x, casted_gamma); + } + const float alpha_; + const float beta_; + const float gamma_; +}; + +struct ClippedRelu +{ + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; - float beta_ = 1.0f; + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + T casted_beta = type_convert(beta_); + y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); + } + const float alpha_; + const float beta_; +}; + +struct LeakyRelu +{ + LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + y = x >= 0 ? x : x * casted_alpha; + } + const float alpha_; +}; + +struct Elu +{ + Elu(float alpha = 1.f) : alpha_(alpha){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + y = x > 0 ? x : casted_alpha * ck::math::expm1(x); + } + const float alpha_; }; } // namespace element_wise diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index e71ade77832ca6e6db873e60b33725a20430a65e..72db56b2a531d25b91cf644c734068ac00fd4e3d 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -138,6 +138,11 @@ struct BlockToCTileMap_M00_N0_M01Adapt __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) : M_(M), N_(N), M01_(M01) { +#if 0 + if(get_thread_global_1d_id()==0){ + printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_); + } +#endif } template @@ -259,6 +264,302 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt +struct BlockToCTileMap_Grouped_M00_N0_M01Adapt; + +template +struct BlockToCTileMap_Grouped_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default; + + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt( + const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default; + __host__ __device__ + BlockToCTileMap_Grouped_M00_N0_M01Adapt(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default; + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt& + operator=(const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default; + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt& + operator=(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default; + + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M, + index_t N, + index_t M01 = 8) + : M_(M), N_(N), M01_(M01) + { +#if 0 + if(get_thread_global_1d_id()==0){ + printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_); + } +#endif + } + + template + __host__ __device__ + BlockToCTileMap_Grouped_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8) + : BlockToCTileMap_Grouped_M00_N0_M01Adapt( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01) + { + } + + __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + + template + __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + + const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum); + auto group_id = block_1d_id % GroupNum; + auto remap_block_1d_id = group_id * group_size + block_1d_id / GroupNum; + + index_t idx_N0 = remap_block_1d_id % N0; + index_t idx_M0 = remap_block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + /** + * idxN0 + * + * |< mtx N >| + * + * NPerBlock NPerBlock NPerBlock NPerBlock + * N_0 N_1 N_2 N_3 + * - |-----------|-----------|-----------|-----|-----|- + * ^ | - - 0 |/----> 2 | | | | + * | | | / | | | | | M_0 MPerBlock + * | M | /| | | | | | + * |-0---|---/-|-----|-----|-----------|-----|-----|- + * | 1 | / | | | blockid | | | + * idxM0 | | | / | V | 5 | | | M_1 MPerBlock + * | - V 1 | - 3 | | | | + * |-----------|-----------|-----------|-----|-----|- + * mtx M | | | | | | + * | | | | | | M_2 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * | | | | | | + * | | | | | | M_3 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * V | | | | | | + * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * Example: + * assume: + * M0 = 5 + * N0 = 4 + * block_1d_id = 5 + * M01 = 2 + * + * idx_N0 = 1 + * idx_M0 = 1 + * M01_adapt = 2 + * idx_M00 = 0 + * idx_M01 = 1 + * idx_N0_M01_local = 5 + * output {1, 2} + */ + + return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t M01_; +}; + +// keep the redundant type argument for backward compatibility +template +struct BlockToCTileMap_Grouped_M00_N0_M01Adapt + : BlockToCTileMap_Grouped_M00_N0_M01Adapt +{ + using BlockToCTileMap_Grouped_M00_N0_M01Adapt:: + BlockToCTileMap_Grouped_M00_N0_M01Adapt; +}; + +// columns of row-vectors +// This C-tile map dynamically adjusts N01 when C-tile index is out of range +template +struct BlockToCTileMap_N00_M0_N01Adapt; + +template +struct BlockToCTileMap_N00_M0_N01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt&) = + default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt&&) = + default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt& + operator=(const BlockToCTileMap_N00_M0_N01Adapt&) = default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt& + operator=(BlockToCTileMap_N00_M0_N01Adapt&&) = default; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8) + : M_(M), N_(N), N01_(N01) + { +#if 0 + if(get_thread_global_1d_id()==0){ + printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_); + } +#endif + } + + template + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t N01 = 8) + : BlockToCTileMap_N00_M0_N01Adapt( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01) + { + } + + __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + + template + __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + + index_t idx_M0 = block_1d_id % M0; + index_t idx_N0 = block_1d_id / M0; + + const auto N01_adapt = (idx_N0 < N0 - N0 % N01_) ? N01_ : N0 % N01_; + + index_t idx_N00 = idx_N0 / N01_; + index_t idx_N01 = idx_N0 % N01_; + index_t idx_M0_N01_local = idx_M0 + idx_N01 * M0; + + /** + * idxN0 + * + * |< mtx N >| + * + * |<---N01--->| + * - |-----------|-----------|-----------|-----|-----|- + * ^ | 0 ----------> 1 | | | | + * | | / | | | | M_0 MPerBlock + * | / | | | | + * |------/----------------|-----------|-----|-----|- + * | | | | | | | + * idxM0 | V | | | | | M_1 MPerBlock + * | 2 ----------> 3 | | | | + * |-----------|-----------|-----------|-----|-----|- + * mtx M | | blockid | | | | + * | | 5 | | | | M_2 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * | | | | | | + * | | | | | | M_3 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * V | | | | | | + * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * NPerBlock NPerBlock NPerBlock NPerBlock + * N_0 N_1 N_2 N_3 + * Example: + * assume: + * N0 = 5 + * M0 = 4 + * block_1d_id = 5 + * N01 = 2 + * + * idx_M0 = 1 + * idx_N0 = 1 + * N01_adapt = 2 + * idx_N00 = 0 + * idx_N01 = 1 + * idx_M0_N01_local = 5 + * output {2, 1} + */ + + return make_tuple(idx_M0_N01_local / N01_adapt, + idx_M0_N01_local % N01_adapt + idx_N00 * N01_); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t N01_; +}; + // 2D slices of column-vectors in 3D space // This C-tile map dynamically adjusts M01 when C-tile index is out of range template @@ -594,7 +895,8 @@ struct OffsettedBlockToCTileMap { using underlying_type = UnderlyingBlockToCTileMap; - OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start) + __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, + index_t block_start) { block_to_ctile_map_ = block_to_ctile_map; block_start_ = block_start; diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index b25f136a3713a412fbcb166f3df304a172ee42f7..206ea00b9d0a94dcd5f12c02948d8f84f9223fb5 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -522,6 +522,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ABDataType, + ABDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index c2f47bd44435bff2ea75cbbee625895cb0c1bef4..9469fa7bc7b27351f393caed43a9bc12a2b8780d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -628,7 +628,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle Gemm1KPack, false, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index d2920570e4f2cefc883aa9bf4bf1d8b845146670..a0924ae3b0525b6605813a8ac58c40cd7ce4c18d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -880,7 +880,12 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle Gemm1KPack, false, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{} + Gemm1KPack * XdlopsGemm{} .K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index 18cfeebcf30d5ea8e44c9a5a5ff1cf564fef01ba..bc76d4cc4fb9e905cdb1547273866b66575ff334 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -794,7 +794,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle Gemm1KPack, true, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index f4b82badf14584d5e59fda6dc6751a83aa67b32f..afb2ad2e760396c200930254f871ff13e032a91a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -649,7 +649,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle Gemm1KPack, true, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp new file mode 100644 index 0000000000000000000000000000000000000000..48ae489f42d7045875b4e080543e0ebfc4f849cf --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple, + const OutGrid1dDescTuple out_grid_1d_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const ElementwiseOperation elementwise_op, + const UnaryOperation unary_op, + const Scale scale_op) +{ + GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple, + out_grid_1d_desc_tuple, + p_in_global_tuple, + p_out_global_tuple, + elementwise_op, + unary_op, + scale_op); +} + +template +struct GridwiseElementwise_1D +{ + static constexpr index_t NumInput = InDataTypePointerTuple::Size(); + static constexpr index_t NumOutput = OutDataTypePointerTuple::Size(); + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size() && + NumInput == InGrid1dDescTuple::Size() && + NumOutput == OutGrid1dDescTuple::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static constexpr auto I0 = Number<0>{}; + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + __device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple, + const OutGrid1dDescTuple out_grid_1d_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const ElementwiseOperation elementwise_op, + const UnaryOperation unary_op, + const Scale scale_op) + { + const index_t thread_global_id = get_thread_global_1d_id(); + + auto in_thread_buf_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return StaticBuffer{}; + }, + Number{}); + + auto out_thread_buf_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return StaticBuffer{}; + }, + Number{}); + + auto in_global_buf_tuple = generate_tuple( + [&](auto I) { + static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1); + + return make_dynamic_buffer( + p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + auto out_global_buf_tuple = generate_tuple( + [&](auto I) { + static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1); + + return make_dynamic_buffer( + p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread); + + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto M = in_grid_1d_desc_tuple[I0].GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * MPerThread; + const auto loop_step_index = make_multi_index(loop_step); + + auto in_global_load_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + InScalarPerVectorSeq::At( + I), // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{in_grid_1d_desc_tuple[I], + thread_global_offset}; + }, + Number{}); + + auto out_global_store_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + OutScalarPerVectorSeq::At(I), + InMemoryDataOperationEnum::Set, + 1, + false>( + out_grid_1d_desc_tuple[I], thread_global_offset, PassThroughOp{}); + }, + Number{}); + + index_t num_iter = M / (loop_step); + do + { + static_for<0, NumInput, 1>{}([&](auto I) { + in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I], + in_global_buf_tuple[I], + thread_buffer_desc_m, + make_tuple(I0), + in_thread_buf_tuple(I)); + + in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I], + loop_step_index); + }); + + static_for<0, MPerThread, 1>{}([&](auto iM) { + // get reference to in data + auto uop_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); }, + Number{}); + + // get reference to dst data + auto out_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); }, + Number{}); + + unpack2(unary_op, uop_data_refs, uop_data_refs); + + auto sop_in_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); }, + Number{}); + + auto sop_out_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); }, + Number{}); + + unpack2(scale_op, sop_out_data_refs, sop_in_data_refs); + + const auto in_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); }, + Number{}); + + unpack2(elementwise_op, out_data_refs, in_data_refs); + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + out_global_store_tuple(I).Run(thread_buffer_desc_m, + make_tuple(I0), + out_thread_buf_tuple[I], + out_grid_1d_desc_tuple[I], + out_global_buf_tuple(I)); + + out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I], + loop_step_index); + }); + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..242996019b280cc0fc29e9ce2522817ac7a43513 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: MIT +// // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_elementwise_3d(const InGrid3dDescTuple in_grid_3d_desc_tuple, + const OutGrid3dDescTuple out_grid_3d_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const ElementwiseOperation elementwise_op, + const index_t num_threads_m, + const index_t num_threads_n, + const index_t num_threads_k) +{ + GridwiseElementwise3dFunctor::Run(in_grid_3d_desc_tuple, + out_grid_3d_desc_tuple, + p_in_global_tuple, + p_out_global_tuple, + elementwise_op, + num_threads_m, + num_threads_n, + num_threads_k); +} + +template +struct GridwiseElementwise_3D +{ + static constexpr index_t NumInput = InDataTypePointerTuple::Size(); + static constexpr index_t NumOutput = OutDataTypePointerTuple::Size(); + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size() && + NumInput == InGrid3dDescTuple::Size() && + NumOutput == OutGrid3dDescTuple::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto thread_buffer_desc_mnk = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + __device__ static void Run(const InGrid3dDescTuple in_grid_3d_desc_tuple, + const OutGrid3dDescTuple out_grid_3d_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const ElementwiseOperation elementwise_op, + const index_t num_threads_m, + const index_t num_threads_n, + const index_t num_threads_k) + { + auto in_thread_buf_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return StaticBuffer{}; + }, + Number{}); + + auto out_thread_buf_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return StaticBuffer{}; + }, + Number{}); + + auto in_global_buf_tuple = generate_tuple( + [&](auto I) { + return make_dynamic_buffer( + p_in_global_tuple[I], in_grid_3d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + auto out_global_buf_tuple = generate_tuple( + [&](auto I) { + return make_dynamic_buffer( + p_out_global_tuple[I], out_grid_3d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + const auto M = in_grid_3d_desc_tuple[I0].GetLength(I0); + const auto N = in_grid_3d_desc_tuple[I0].GetLength(I1); + const auto K = in_grid_3d_desc_tuple[I0].GetLength(I2); + + const index_t loop_step_m = num_threads_m * MPerThread; + const index_t loop_step_n = num_threads_n * NPerThread; + const index_t loop_step_k = num_threads_k * KPerThread; + + const index_t thread_1d_id = get_thread_global_1d_id(); + + const index_t tid_m = thread_1d_id / (num_threads_n * num_threads_k); + const index_t tid_nk = thread_1d_id % (num_threads_n * num_threads_k); + const index_t tid_n = tid_nk / num_threads_k; + const index_t tid_k = tid_nk % num_threads_k; + + const auto thread_global_offset = + make_multi_index(tid_m * MPerThread, tid_n * NPerThread, tid_k * KPerThread); + + auto in_global_load_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return ThreadwiseTensorSliceTransfer_v2< + DataType, + DataType, + decltype(in_grid_3d_desc_tuple[I]), + decltype(thread_buffer_desc_mnk), + Sequence, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 01, // SrcVectorDim + InScalarPerVectorSeq::At(I), // InScalarPerVectorSeq::At(I), // + // ScalarPerVector + 1, // SrcScalarStrideInVector + true>{in_grid_3d_desc_tuple[I], thread_global_offset}; + }, + Number{}); + + auto out_global_store_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return ThreadwiseTensorSliceTransfer_v1r3< + DataType, + DataType, + decltype(thread_buffer_desc_mnk), + decltype(out_grid_3d_desc_tuple[I]), + PassThroughOp, + Sequence, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + OutScalarPerVectorSeq::At(I), // OutScalarPerVectorSeq::At(I), + InMemoryDataOperationEnum::Set, + 1, + true>(out_grid_3d_desc_tuple[I], thread_global_offset, PassThroughOp{}); + }, + Number{}); + + index_t num_iter_m = M / (loop_step_m); + do + { + index_t num_iter_n = N / (loop_step_n); + do + { + index_t num_iter_k = K / (loop_step_k); + do + { + static_for<0, NumInput, 1>{}([&](auto I) { + in_global_load_tuple(I).Run(in_grid_3d_desc_tuple[I], + in_global_buf_tuple[I], + thread_buffer_desc_mnk, + make_tuple(I0, I0, I0), + in_thread_buf_tuple(I)); + + in_global_load_tuple(I).MoveSrcSliceWindow( + in_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k)); + }); + + static_for<0, MPerThread, 1>{}([&](auto iM) { + static_for<0, NPerThread, 1>{}([&](auto iN) { + static_for<0, KPerThread, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_mnk.CalculateOffset(make_tuple(iM, iN, iK)); + // get reference to in data + const auto in_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> const auto& { + return in_thread_buf_tuple(I)(Number{}); + }, + Number{}); + + // get referenec to dst data + auto out_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { + return out_thread_buf_tuple(I)(Number{}); + }, + Number{}); + unpack2(elementwise_op, out_data_refs, in_data_refs); + }); + }); + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + out_global_store_tuple(I).Run(thread_buffer_desc_mnk, + make_tuple(I0, I0, I0), + out_thread_buf_tuple[I], + out_grid_3d_desc_tuple[I], + out_global_buf_tuple(I)); + + out_global_store_tuple(I).MoveDstSliceWindow( + out_grid_3d_desc_tuple[I], make_multi_index(0, 0, loop_step_k)); + }); + } while(--num_iter_k); + + static_for<0, NumInput, 1>{}([&](auto I) { + in_global_load_tuple(I).MoveSrcSliceWindow( + in_grid_3d_desc_tuple[I], + make_multi_index(0, loop_step_n, -(K / loop_step_k) * loop_step_k)); + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + out_global_store_tuple(I).MoveDstSliceWindow( + out_grid_3d_desc_tuple[I], + make_multi_index(0, loop_step_n, -(K / loop_step_k) * loop_step_k)); + }); + + } while(--num_iter_n); + + static_for<0, NumInput, 1>{}([&](auto I) { + in_global_load_tuple(I).MoveSrcSliceWindow( + in_grid_3d_desc_tuple[I], + make_multi_index(loop_step_m, + -(N / loop_step_n) * loop_step_n, + -(K / loop_step_k) * loop_step_k)); + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + out_global_store_tuple(I).MoveDstSliceWindow( + out_grid_3d_desc_tuple[I], + make_multi_index(loop_step_m, + -(N / loop_step_n) * loop_step_n, + -(K / loop_step_k) * loop_step_k)); + }); + } while(--num_iter_m); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp index 3ea72b85345869a938e52f872c2c586c6e36170f..072275c089ce469a1d7e043799e290a28c0b310c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp @@ -119,7 +119,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk index_t num_k_block_tile_iteration, AccDataType epsilon, const InDataTypePointerTuple p_in_global_tuple, - XDataType* const __restrict__ p_x_lds, + XDataType* const __restrict__ p_x_lds_, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, @@ -149,7 +149,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); auto x_lds_val_buf = make_dynamic_buffer( - p_x_lds, x_grid_desc_m_k.GetElementSpaceSize() / grid_size); + p_x_lds_, x_grid_desc_m_k.GetElementSpaceSize() / grid_size); auto in_thread_buf_tuple = generate_tuple( [&](auto) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index 3ced4b9ad61fe0be2783fb6f261b9131fb8d1622..35176591c1cfdd4ed0037b5baf24965660e4a0ee 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -67,7 +67,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -504,6 +504,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 1d1bb6ed2d446fc61f4056acea275e048f002997..1da7236978eb042c4270179dcc502ea3593021a2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -7,11 +7,9 @@ #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" -#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp" #include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" @@ -19,8 +17,6 @@ namespace ck { -using GemmDlAlgorithm = tensor_operation::device::GemmDlAlgorithm; - template + bool HasDoubleTailKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -43,13 +38,6 @@ __global__ void const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map) { -// DPP8 is currently only supported on gfx1030 -#if !defined(__gfx1030__) - if(GemmDlAlg == GemmDlAlgorithm::Dpp8) - { - return; - } -#endif constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -100,8 +88,7 @@ template + index_t CThreadTransferDstScalarPerVector> struct GridwiseGemmDl_km_kn_mn_v1r3 { static constexpr auto I0 = Number<0>{}; @@ -257,45 +244,6 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 c_grid_desc_m_n); } - template - __host__ __device__ static constexpr auto GetBlockwiseGemm() - { - if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8) - { - return BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0< - BlockSize, - FloatAB, - FloatAB, - FloatAcc, - ABlockDesc_BK0_BM_BK1, - BBlockDesc_BK0_BN_BK1, - M1PerThreadM111, - N1PerThreadN111, - KPerThread, - M11N11ThreadClusterM110Xs, - M11N11ThreadClusterN110Xs, - M1PerThreadM111, - N1PerThreadN111>{}; - } - else - { - return BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< - BlockSize, - FloatAB, - FloatAB, - FloatAcc, - ABlockDesc_BK0_BM_BK1, - BBlockDesc_BK0_BN_BK1, - M1PerThreadM111, - N1PerThreadN111, - KPerThread, - M11N11ThreadClusterM110Xs, - M11N11ThreadClusterN110Xs, - M1PerThreadM111, - N1PerThreadN111>{}; - } - } - using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); using CGridDesc_M0_M10_M11_N0_N10_N11 = @@ -424,7 +372,20 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register const auto blockwise_gemm = - GetBlockwiseGemm(); + BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, + M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b473d7cbf2b73070eacc38cce27b419e52adb179 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp @@ -0,0 +1,701 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif +#if CK_USE_WAVES_PER_EU + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) +#endif + kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx103__) || defined(__gfx11__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane( + GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA)); + const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane( + GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane( + GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_m_n); +#else + ignore = karg; +#endif +} + +template +struct GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + static constexpr auto max_lds_align = math::lcm(AK1, BK1); + + using ThisThreadBlock = ThisThreadBlock; + // return block_id to C matrix tile idx (m0, n0) mapping + using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) { return math::integer_divide_floor(K, AK1Value); } + __host__ static auto CalculateBK0(index_t K) { return math::integer_divide_floor(K, BK1Value); } + + // Argument + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + AK0{CalculateAK0(K)}, + BK0{CalculateBK0(K)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t AK0; + index_t BK0; + }; + + // Argument + struct Argument : public Problem, public tensor_operation::device::BaseArgument + { + __host__ Argument(const ABDataType* p_a_grid_, + const ABDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const ABDataType* p_a_grid; + const ABDataType* p_b_grid; + CDataType* p_c_grid; + }; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, AK1), max_lds_align); + } + }(); + + return a_block_desc_ak0_m_ak1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BK1), max_lds_align); + } + }(); + + return b_block_desc_bk0_n_bk1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType); + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "Wrong! AK1 must be known at the time of compilation."); + static_assert(is_known_at_compile_time>::value, + "Wrong! BK1 must be known at the time of compilation."); + + static_assert( + MPerBlock % (MPerDpp * MDppPerWave) == 0, + "Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave."); + static_assert( + NPerBlock % (NPerDpp * NDppPerWave) == 0, + "Invalid tuning parameters! NPerBlock must be divisible by NPerDpp * NDppPerWave."); + + static_assert( + KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1."); + + static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0, + "Invalid tuning parameters! AK1Value must be divisible by " + "ABlockTransferDstScalarPerVector_K1"); + + static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0, + "Invalid tuning parameters! BK1Value must be divisible by " + "BBlockTransferDstScalarPerVector_K1"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if(problem.K % KPerBlock != 0) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = problem.K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const auto num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n) + { + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), DppSelector::selected_dpp.k_per_dpp); + + using BlockwiseGemm = + BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n); + } + + static constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + __device__ static auto + MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + __device__ static auto + MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_pass_through_transform(N), + make_unmerge_transform(make_tuple(BK0, BK1Value))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + } + + __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + template + __device__ static void Run(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + const auto block_2_ctile_map = + Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)}; + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0), + c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[AK0PerBlock, MPerBlock] is in LDS + // b_mtx[BK0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), DppSelector::selected_dpp.k_per_dpp); + auto blockwise_gemm = + BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0); + // (AK0 / AK0PerBlock) is always equal to (BK0 / BK0PerBlock) + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2(); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_n2, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4b7cc5679657ffd31dde5705391119630d1f8c3d --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,1040 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +namespace ck { + +// GEMM: +// input : A0[M, K], A1[M, K] +// input : B0[N, K], B1[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleABD_xdl_cshuffle +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + +#if CK_WORKAROUND_DENORM_FIX + using ComputeDataType = + conditional_t, ck::bhalf_t, ComputeDataType_>; +#else + using ComputeDataType = ComputeDataType_; +#endif + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + static constexpr auto MakeAsGridPointer() + { + return generate_tuple( + [&](auto i) { + using ADataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeBsGridPointer() + { + return generate_tuple( + [&](auto i) { + using BDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(ComputeDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + // A desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k) + { + return generate_tuple( + [&](auto i) { return MakeDefaultAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); }, + Number{}); + } + + // B desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k) + { + return generate_tuple( + [&](auto i) { return MakeDefaultBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); }, + Number{}); + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AsGridDesc_M_K& as_grid_desc_m_k, + const BsGridDesc_N_K& bs_grid_desc_n_k, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "KPerBlock must be divisible by AK1Value and BK1Value!"); + + const auto M = as_grid_desc_m_k[I0].GetLength(I0); + const auto N = bs_grid_desc_n_k[I0].GetLength(I0); + const auto AK = as_grid_desc_m_k[I0].GetLength(I1); + const auto BK = bs_grid_desc_n_k[I0].GetLength(I1); + + // check consistency of desc + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK)) + { + return false; + } + + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + bool valid = true; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + valid = + valid && (as_grid_desc_m_k[i].GetElementSpaceSize() * sizeof(ADataType) <= TwoGB); + valid = valid && (M == as_grid_desc_m_k[i].GetLength(I0) && + AK == as_grid_desc_m_k[i].GetLength(I1)); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + valid = + valid && (bs_grid_desc_n_k[i].GetElementSpaceSize() * sizeof(BDataType) <= TwoGB); + valid = valid && (N == bs_grid_desc_n_k[i].GetLength(I0) && + BK == bs_grid_desc_n_k[i].GetLength(I1)); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0)) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = AK / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // check block-to-E-tile + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + + if(!(e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using AsGridPointer = decltype(MakeAsGridPointer()); + using BsGridPointer = decltype(MakeBsGridPointer()); + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + template + __host__ __device__ static auto + MakeAsGridDescriptor_M_K(const std::array& MRaws, + const std::array& KRaws, + const std::array& AsStride) + { + return generate_tuple( + [&](auto i) { + using ALayout = remove_cvref_t>; + + return MakeAGridDescriptor_M_K(MRaws[i], KRaws[i], AsStride[i]); + }, + Number{}); + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + __host__ __device__ static auto + MakeBsGridDescriptor_N_K(const std::array& KRaws, + const std::array& NRaws, + const std::array& BsStride) + { + return generate_tuple( + [&](auto i) { + using BLayout = remove_cvref_t>; + + return MakeBGridDescriptor_N_K(KRaws[i], NRaws[i], BsStride[i]); + }, + Number{}); + } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(AsGridPointer p_as_grid, + BsGridPointer p_bs_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto as_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto bs_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + const auto idx_as_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, + Number{}); + + static_assert(ABlockTransferSrcScalarPerVector == ABlockTransferDstScalarPerVector_AK1, + "Src and Dst ScalarPerVector must be the same"); + + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + AsDataType, + Tuple, + decltype(as_grid_desc_ak0_m_ak1), + decltype(tie(a_block_desc_ak0_m_ak1)), + AElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + uniform_sequence_gen_t, + Sequence>{as_grid_desc_ak0_m_ak1, + idx_as_block_begin, + tie(a_block_desc_ak0_m_ak1), + make_tuple(make_multi_index(0, 0, 0)), + a_element_op}; + + const auto idx_bs_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, + Number{}); + + static_assert(BBlockTransferSrcScalarPerVector == BBlockTransferDstScalarPerVector_BK1, + "Src and Dst ScalarPerVector must be the same"); + + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + BsDataType, + Tuple, + decltype(bs_grid_desc_bk0_n_bk1), + decltype(tie(b_block_desc_bk0_n_bk1)), + BElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + uniform_sequence_gen_t, + Sequence>{bs_grid_desc_bk0_n_bk1, + idx_bs_block_begin, + tie(b_block_desc_bk0_n_bk1), + make_tuple(make_multi_index(0, 0, 0)), + b_element_op}; + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeDataType, // ComputeDataType for A + ComputeDataType, // ComputeDataType for B + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) / + KPerBlock); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + gridwise_gemm_pipeline.template Run(as_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + as_grid_buf, + a_block_buf, + a_block_slice_copy_step, + bs_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + bs_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run(AsGridPointer p_as_grid, + BsGridPointer p_bs_grid, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const std::array StrideAs, + const std::array StrideBs, + const std::array StrideDs, + const index_t StrideE, + const Block2ETileMap& block_2_etile_map) + { + using AsGridDesc_M_K = + remove_cvref_t({}, {}, {}))>; + using BsGridDesc_N_K = + remove_cvref_t({}, {}, {}))>; + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + AsGridDesc_M_K as_grid_desc_m_k; + BsGridDesc_N_K bs_grid_desc_n_k; + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumATensor, 1>{}([&](auto j) { + using ALayout = remove_cvref_t>; + + as_grid_desc_m_k(j) = MakeAGridDescriptor_M_K(M, K, StrideAs[j]); + }); + + static_for<0, NumBTensor, 1>{}([&](auto j) { + using BLayout = remove_cvref_t>; + + bs_grid_desc_n_k(j) = MakeBGridDescriptor_N_K(N, K, StrideBs[j]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto as_grid_desc_ak0_m_ak1 = MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k); + + const auto bs_grid_desc_bk0_n_bk1 = MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index d710fc18944ddb538ffb2d0385db862958e0542f..5c9f40b51a5ba28810352cf9dcf6d6e2dca1a404 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -470,6 +470,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 98ade85a3089ccc5638fb45d4270488ebab59f3a..f514e3a1198601c263d1a9ccd9d992b152986e40 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -36,7 +36,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle( + kernel_grouped_conv_multiple_d_wmma_cshuffle( const ADataType* __restrict__ p_a_grid, const BDataType* __restrict__ p_b_grid, DsPointer p_ds_grid, @@ -54,8 +54,7 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -148,8 +147,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) // printf("entry kernel launch"); __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; @@ -244,8 +242,7 @@ __global__ void const CDEElementwiseOperation cde_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; GridwiseOp::template Run(p_a_grid, @@ -274,7 +271,7 @@ __global__ void ignore = b_element_op; ignore = cde_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx1100__ )) +#endif // end of if (defined(__gfx11__ )) } template < // DataType Family @@ -452,11 +449,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + // CheckValidity for kernels without multi D template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const DsGridDesc_M_N& ds_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n, const Block2CTileMap& block_2_ctile_map) { @@ -471,18 +468,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - bool valid = true; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && - N == ds_grid_desc_m_n[i].GetLength(I1)); - }); - - if(!valid) - { - return false; - } - if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && K1 == b_grid_desc_k0_n_k1.GetLength(I2))) @@ -517,6 +502,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle return true; } + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + bool valid = true; + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + return CheckValidity( + a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n, block_2_ctile_map); + } + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / (K0PerBlock * K1); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 3b8a5ec8f50a400af8b317e534a32c8113a1018f..15c30a0dadc51334a6463f80d39ba8c440ad363f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -15,6 +15,9 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + namespace ck { // GEMM: @@ -26,7 +29,9 @@ namespace ck { // E = cde_op(C, D0, D1, ...) // Assume: // D0, D1, ... and E have the same layout -template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename BComputeDataType = AComputeDataType_> struct GridwiseGemmMultipleD_xdl_cshuffle { static constexpr index_t NumDTensor = DsDataType::Size(); + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -92,15 +100,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using GridwiseGemmPipe = remove_cvref_t< decltype(GridwiseGemmPipeline_Selector())>; - // denorm test fix, required to work around fp16 mfma issue - // we convert fp16->fp32->bf16 and execute bf16 mfma instruction - // when mfma if fixed, remove this section and update - // ABDataTypeAdjusted -> ABDataType throughout this file #if CK_WORKAROUND_DENORM_FIX - using ABDataTypeAdjusted = - conditional_t, ck::bhalf_t, ABDataType>; + using AComputeDataType = + conditional_t, ck::bhalf_t, AComputeDataType_>; #else - using ABDataTypeAdjusted = ABDataType; + using AComputeDataType = AComputeDataType_; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -169,8 +173,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ABDataType), + return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) + + b_block_space_size_aligned * sizeof(BComputeDataType), c_block_size * sizeof(CShuffleDataType)); } @@ -265,6 +269,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "KPerBlock must be divisible by AK1Value and BK1Value!"); const auto M = a_grid_desc_m_k.GetLength(I0); const auto N = b_grid_desc_n_k.GetLength(I0); @@ -313,8 +319,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // check tensor size: cannot be larger than 2GB each constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && - b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) { return false; @@ -332,14 +338,102 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using DsGridPointer = decltype(MakeDsGridPointer()); + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + template - __device__ static void Run(const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer p_ds_grid, EDataType* __restrict__ p_e_grid, void* __restrict__ p_shared, @@ -408,8 +502,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle Sequence, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, - ABDataType, - ABDataTypeAdjusted, + ADataType, + AComputeDataType, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -439,8 +533,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle Sequence, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, - ABDataType, - ABDataTypeAdjusted, + BDataType, + BComputeDataType, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -468,13 +562,15 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = - math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ABDataTypeAdjusted, + AComputeDataType, + BComputeDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -492,11 +588,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), - a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); @@ -761,6 +856,85 @@ struct GridwiseGemmMultipleD_xdl_cshuffle }); } } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + // tensor descriptors for problem definiton + const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K(M, K, StrideA); + const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K(K, N, StrideB); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + + const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cd36b9e51ae6db6c74f491b253ae7099a259d747 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -0,0 +1,1000 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_lds.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + +#if CK_WORKAROUND_DENORM_FIX + using AComputeDataType = + conditional_t, ck::bhalf_t, AComputeDataType_>; +#else + using AComputeDataType = AComputeDataType_; +#endif + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, destination of blockwise copy. + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, destination of blockwise copy. + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment. + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle. + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max( + NumGemmKPrefetchStage * a_block_space_size_aligned * sizeof(AComputeDataType) + + NumGemmKPrefetchStage * b_block_space_size_aligned * sizeof(BComputeDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ __device__ static auto + MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + __host__ __device__ static auto + MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + // A desc for source in blockwise copy. + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy. + __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // E desc for destination in blockwise copy. + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy. + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using Block2ETileMap = remove_cvref_t; + + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, + const BGridDesc_N_K& b_grid_desc_n_k, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "KPerBlock must be divisible by AK1Value and BK1Value!"); + + static_assert( + std::is_same_v && + std::is_same_v, + "Direct load transfers do not support elementwise operations other than passthrough."); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto AK = a_grid_desc_m_k.GetLength(I1); + const auto BK = b_grid_desc_n_k.GetLength(I1); + + // Check the consistency of descriptors. + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK)) + { + return false; + } + + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + // Check the tile size. + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0)) + { + return false; + } + + // Check gridwise gemm pipeline. + const auto num_k_loop = AK / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // Check block-to-E-tile. + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // Check tensor size: cannot exceed 2GB. + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + // Elementwise operations are not supported for A and B, arguments left only for the API + // consistency. + (void)a_element_op; + (void)b_element_op; + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // Divide block work by [M, N]. + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // This forces m/n_block_data_idx_on_grid into SGPR. + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, destination of blockwise copy. + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, destination of blockwise copy. + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ADataType, + AComputeDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcVectorDim, + 2, + ABlockTransferScalarPerVector>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BDataType, + BComputeDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcVectorDim, + 2, + BBlockTransferScalarPerVector>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), + MfmaSelector::selected_mfma + .k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + AComputeDataType, + BComputeDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment. + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + const auto a_buffers_offset = 0; + auto a_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared, + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), + a_buffers_offset, + max_lds_align); + const auto b_buffers_offset = a_block_space_size_aligned * NumGemmKPrefetchStage; + auto b_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared, + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), + b_buffers_offset, + max_lds_align); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buffers, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buffers, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // Shuffle C and write out. + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // Calculate the origin of thread output tensor on global memory. + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // Shuffle: threadwise copy C from VGPR to LDS. + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // A tuple of reference to C/Ds tensor descriptors. + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // A tuple of reference to C/Ds grid buffers. + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // A tuple of starting index of C/Ds blockwise copy. + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // Blockwise copy C/D/E between LDS and global. + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // Space filling curve for threadwise C in VGPR before shuffle. + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // Space filling curve for shuffled blockwise C/D/E. + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // Make sure it's safe to write to LDS. + block_sync_lds(); + + // Each thread write its data from VGPR to LDS. + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // Make sure it's safe to read from LDS. + block_sync_lds(); + + // Each block copy its data from LDS to global. + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // Move on Ds. + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // Move on E. + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + struct Argument : public tensor_operation::device::BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_m_k_{MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, + b_grid_desc_n_k_{MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)}, + a_grid_desc_ak0_m_ak1_{MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw} + { + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + ds_grid_desc_m_n_(i) = MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); + }); + + if(CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // Pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // Tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // Tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise ops + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // For checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4f37049b203b9f716e50fef2d13b72b2e345aba4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -0,0 +1,1077 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleD_xdl_splitk_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, AK0PerBlock, Number{}, AK1), + make_tuple(AK0PerBlock * Number{} * AK1, + Number{} * AK1, + AK1, + I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, BK0PerBlock, Number{}, BK1), + make_tuple(BK0PerBlock * Number{} * BK1, + Number{} * BK1, + BK1, + I1)); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(ComputeType), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch) + { + return math::integer_least_multiple(K, KPerBlock * K_Batch); + } + + template + __host__ __device__ static auto + MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto MPad = CalculateMPadded(M); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto AK0 = KPad / (KBatch * AK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto NPad = CalculateNPadded(N); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto BK0 = KPad / (KBatch * BK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + template + __host__ __device__ static constexpr bool + CheckValidity(const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch) + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + ignore = StrideDs; + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + +#if 0 + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } +#endif + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const index_t KBatch, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation_& cde_element_op, + const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1, + const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_kbatch_ak0_m_ak1 = + GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_kbatch_bk0_n_bk1 = + GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ComputeType, + decltype(a_grid_desc_kbatch_ak0_m_ak1), + decltype(a_block_desc_kbatch_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_kbatch_ak0_m_ak1, + make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_kbatch_ak0_m_ak1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + ComputeType, + decltype(b_grid_desc_kbatch_bk0_n_bk1), + decltype(b_block_desc_kbatch_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_kbatch_bk0_n_bk1, + make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_kbatch_bk0_n_bk1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeType, + ComputeType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + +#if 1 + if(block_work_idx[I0] == 0) + { + const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; + const index_t numNThreads = NPerBlock / nThreadSize; + const index_t numMThreads = BlockSize / numNThreads; + const index_t mThreadSize = MPerBlock / numMThreads; + + const index_t m_tid = get_thread_local_1d_id() / numNThreads; + const index_t n_tid = get_thread_local_1d_id() % numNThreads; + + auto c_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + StaticBuffer + e_thread_zero_buf; + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + EDataType, + EDataType, + decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, mThreadSize, 1, nThreadSize>, + Sequence<0, 1, 2, 3>, + 3, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], + m_tid * mThreadSize, + block_work_idx[I2], + n_tid * nThreadSize), + ck::tensor_operation::element_wise::PassThrough{}}; + + c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_zero_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + __syncthreads(); + + if(threadIdx.x == 0) + { + atomicAdd(barrier_count_finished, 1); + } + } +#endif + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = + __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_kbatch_ak0_m_ak1, + a_block_desc_kbatch_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_kbatch_bk0_n_bk1, + b_block_desc_kbatch_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + if(threadIdx.x == 0) + { + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + } + + __syncthreads(); + + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); + }, + Number{})); + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation_, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), + cde_element_op}; + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor_, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + + if(threadIdx.x == 0) + { + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); + + if(k_id_finished_t == KBatch) + { + *barrier_count_finished = 0; + } + } + } + } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + if(kbatch_id == KBatch - 1) + { + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + else + { + Run>( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + ck::tensor_operation::element_wise::PassThrough{}, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 7c11316cefa8058b5a127c316ddad5d6d7a5fad4..ee5a3e91e3e2c9c373458f6dba1f24f81ac93a8c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -3,18 +3,22 @@ #pragma once -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" #ifndef __HIPCC_RTC__ #include #endif +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp" + namespace ck { enum struct PipelineVersion { v1, v2, + // v3 is only used in the Stream-K implementation. + v4, }; template {}; + } else { #ifndef __HIPCC_RTC__ diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index d1209636de93a13cbe7f851575b4abaedd99af72..754a3e89c9327792fb370e30fbacdee2b9758163 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -4,7 +4,8 @@ #pragma once #include "ck/utility/common_header.hpp" -#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp index d3d7d5af8577d1c2ed65fa817f708379c6089a25..25e1cebdbe06918cfea98132209116a2d74de822 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp @@ -9,13 +9,13 @@ namespace ck { struct GridwiseGemmPipeline_v2 { - __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + __host__ __device__ static constexpr bool IsSupported(const index_t num_loop) { // TODO: improve applicability return num_loop % 2 == 0; } - __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + __host__ __device__ static constexpr bool CalculateHasMainLoop(const index_t num_loop) { return (num_loop / 2) > 1; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..08d986d0da62c01c970ee10ceab6b8e000392b46 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace lds_direct_load { + +__device__ void sched_barrier() +{ +#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM + // When direct loads and `waitcnt` instructions are submitted using inline asm, the usage of + // `sched_barrier` is necessary to make sure no instructions that use the loaded memory + // are scheduled by the compiler before the `waitcnt` instruction. + __builtin_amdgcn_sched_barrier(0); +#endif +} + +} // namespace lds_direct_load + +namespace ck { + +template +struct GridwiseGemmPipeline_v4; + +// 1-stage prefetch +template <> +struct GridwiseGemmPipeline_v4<1> +{ + static constexpr auto I0 = Number<0>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffers& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffers& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + static_assert(ABlockBuffers::Size() == 1 && BBlockBuffers::Size() == 1); + auto& a_block_buf = a_block_bufs.At(I0); + auto& b_block_buf = b_block_bufs.At(I0); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +// 2-stages prefetch +template <> +struct GridwiseGemmPipeline_v4<2> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + { + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return (num_loop / 2) > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffers& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffers& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + static_assert(ABlockBuffers::Size() == 2 && BBlockBuffers::Size() == 2); + auto& a_block_buf1 = a_block_bufs.At(I0); + auto& a_block_buf2 = a_block_bufs.At(I1); + auto& b_block_buf1 = b_block_bufs.At(I0); + auto& b_block_buf2 = b_block_bufs.At(I1); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf); + + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + { + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf); + + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 99e410f688312b47dbeba4afb1d1712624561a76..0e5777e561476da30681b67f6d0d85a152a683f5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -55,7 +55,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -457,6 +457,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index 18cf80041b71c9f581462cef10f7fd3508e96930..e7dc0d3eb0d9296f4e9ca98e68d0cadfb3cbcfb1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -588,6 +588,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ABDataType, + ABDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -1012,6 +1013,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ABDataType, + ABDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..caf8f040f4a78232386a1f0c3d0be02b4dc63295 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp @@ -0,0 +1,1076 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleD_xdl_splitk_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, AK0PerBlock, Number{}, AK1), + make_tuple(AK0PerBlock * Number{} * AK1, + Number{} * AK1, + AK1, + I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, BK0PerBlock, Number{}, BK1), + make_tuple(BK0PerBlock * Number{} * BK1, + Number{} * BK1, + BK1, + I1)); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch) + { + return math::integer_least_multiple(K, KPerBlock * K_Batch); + } + + template + __host__ __device__ static auto + MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto MPad = CalculateMPadded(M); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto AK0 = KPad / (KBatch * AK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto NPad = CalculateNPadded(N); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto BK0 = KPad / (KBatch * BK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + template + __host__ __device__ static constexpr bool + CheckValidity(const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch) + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + ignore = StrideDs; + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + +#if 0 + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } +#endif + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const index_t KBatch, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation_& cde_element_op, + const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1, + const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_kbatch_ak0_m_ak1 = + GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_kbatch_bk0_n_bk1 = + GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ComputeType, + decltype(a_grid_desc_kbatch_ak0_m_ak1), + decltype(a_block_desc_kbatch_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_kbatch_ak0_m_ak1, + make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_kbatch_ak0_m_ak1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + ComputeType, + decltype(b_grid_desc_kbatch_bk0_n_bk1), + decltype(b_block_desc_kbatch_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_kbatch_bk0_n_bk1, + make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_kbatch_bk0_n_bk1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + +#if 1 + if(block_work_idx[I0] == 0) + { + const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; + const index_t numNThreads = NPerBlock / nThreadSize; + const index_t numMThreads = BlockSize / numNThreads; + const index_t mThreadSize = MPerBlock / numMThreads; + + const index_t m_tid = get_thread_local_1d_id() / numNThreads; + const index_t n_tid = get_thread_local_1d_id() % numNThreads; + + auto c_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + StaticBuffer + e_thread_zero_buf; + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + EDataType, + EDataType, + decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, mThreadSize, 1, nThreadSize>, + Sequence<0, 1, 2, 3>, + 3, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], + m_tid * mThreadSize, + block_work_idx[I2], + n_tid * nThreadSize), + ck::tensor_operation::element_wise::PassThrough{}}; + + c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_zero_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + __syncthreads(); + + if(threadIdx.x == 0) + { + atomicAdd(barrier_count_finished, 1); + } + } +#endif + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = + __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_kbatch_ak0_m_ak1, + a_block_desc_kbatch_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_kbatch_bk0_n_bk1, + b_block_desc_kbatch_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + if(threadIdx.x == 0) + { + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + } + + __syncthreads(); + + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); + }, + Number{})); + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation_, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), + cde_element_op}; + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor_, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + + if(threadIdx.x == 0) + { + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); + + if(k_id_finished_t == KBatch) + { + *barrier_count_finished = 0; + } + } + } + } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + if(kbatch_id == KBatch - 1) + { + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + else + { + Run>( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + ck::tensor_operation::element_wise::PassThrough{}, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index d8b31311b1040f6c77e955550875a2a3e1337530..066cfc62f2e553e400314b00dd2d92ce11f61f6d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -49,8 +49,7 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ - defined(__gfx1102__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -75,7 +74,7 @@ __global__ void ignore = b_element_op; ignore = c_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx1100__)) +#endif // end of if (defined(__gfx11__)) } template ( @@ -50,7 +50,7 @@ __global__ void typename GridwiseGemm::Problem problem) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, p_shared, problem); @@ -108,7 +108,8 @@ template + typename ComputeTypeA = FloatC, + typename ComputeTypeB = ComputeTypeA> struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 { static constexpr auto I0 = Number<0>{}; @@ -547,8 +548,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned * sizeof(ComputeType) + - b_block_space_size_aligned * sizeof(ComputeType)), + return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) + + b_block_space_size_aligned * sizeof(ComputeTypeB)), c_block_size * sizeof(FloatCShuffle)); } @@ -750,7 +751,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, FloatA, - ComputeType, + ComputeTypeA, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -781,7 +782,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, FloatB, - ComputeType, + ComputeTypeB, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -809,13 +810,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = - math::max(math::lcm(AK1Number, BK1Number), - MfmaSelector::selected_mfma.k_per_blk); + constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeType, + ComputeTypeA, + ComputeTypeB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -833,10 +835,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..db9625c6e6032114b87209bfc47048f4c6d01b86 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -0,0 +1,1153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +#endif + kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + typename GridwiseGemm::Problem problem) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = problem; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_v2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + return CalculateKPadded(K) / AK1Value; + } + else + { + return K / AK1Value; + } + } + + __host__ static auto CalculateBK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + return CalculateKPadded(K) / BK1Value; + } + else + { + return K / BK1Value; + } + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_floor(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_floor(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KPadded{CalculateKPadded(K_)}, + AK0{CalculateAK0(K_)}, + BK0{CalculateBK0(K_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const FloatA* p_a_grid; + const FloatB* p_b_grid; + FloatC* p_c_grid; + }; + + // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) + + b_block_space_size_aligned * sizeof(ComputeTypeB)), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding) + { + if(!(CalculateKPadded(problem.K) % AK1Value == 0) || + !(CalculateKPadded(problem.K) % BK1Value == 0)) + { + return false; + } + } + else + { + if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock; + + if(num_k_loop < 4) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return num_loop > 3; + } + + __host__ static constexpr index_t CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + if(num_loop % 2 == 1) + return 3; + else + return 2; + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } +#if 0 + if(threadIdx.x == 0){ + printf("Hardware assigned No. %03d workgroup of logical C tile (%02d, %02d) on %d th XCC Die, %d th SE, %d th CU\n", + get_block_1d_id(), + block_work_idx[I0], + block_work_idx[I1], + __smid()>>6 & 0xf, + __smid()>>4 & 0x3, + __smid() & 0xf); + } +#endif + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatA, + ComputeTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatB, + ComputeTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + // BlockSize, + // ComputeType, + // FloatGemmAcc, + // decltype(a_block_desc_ak0_m_ak1), + // decltype(b_block_desc_bk0_n_bk1), + // MPerXdl, + // NPerXdl, + // MXdlPerWave, + // NXdlPerWave, + // KPack, + // LoopSched>(); + auto blockwise_gemm_pipeline = BlockwiseGemmXdlops_pipeline_v4< + BlockSize, + ComputeTypeA, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>{}; // TransposeC + + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // gridwise GEMM pipeline + static_assert(std::is_default_constructible_v); + // const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 0404d88ab8b4faebbe8613453cb4274cf631fe2b..7f815de1f9d3ffc145664e8781d6f7af20cbcde1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -58,7 +58,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; // TODO ANT: separate into MMA + Epilogue @@ -495,6 +495,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index bbd01a238e5d5f5d071ddec519a08d68e8fef206..8675a9242acf0d573cc0316f3f89b95b654b38af 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -494,6 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< TileMathThreadGroupSize, ABDataType, + ABDataType, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 0920a17fc7f9cfa6265f1599d6deade702fd53ed..5617f67f8be20f9d775f157c6e24804c4a906241 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -139,7 +139,8 @@ __host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLen } template (p_a_grid, @@ -181,21 +182,22 @@ __global__ void c_element_op, c_block_cluster_adaptor); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_b_k0_m_k1_grid_desc; - ignore = b_b_k0_n_k1_grid_desc; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = c_block_cluster_adaptor; + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeTypeA = FloatA, + typename ComputeTypeB = ComputeTypeA> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight { static constexpr auto I0 = Number<0>{}; @@ -265,11 +269,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // denorm test fix, required to work around fp16 mfma issue // we convert fp16->fp32->bf16 and execute bf16 mfma instruction // when mfma if fixed, remove this section and update - // FloatABAdjusted -> FloatAB throughout this file + // FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB, + // throughout this file #if CK_WORKAROUND_DENORM_FIX - using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; + using FloatAAdjusted = + conditional_t, ck::bhalf_t, ComputeTypeA>; + using FloatBAdjusted = + conditional_t, ck::bhalf_t, ComputeTypeB>; #else - using FloatABAdjusted = FloatAB; + using FloatAAdjusted = ComputeTypeA; + using FloatBAdjusted = ComputeTypeB; #endif // M0/M1/M1Padding @@ -506,7 +515,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr auto c_block_size = GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); - return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), + return math::max((a_block_space_size * sizeof(FloatAAdjusted) + + b_block_space_size * sizeof(FloatBAdjusted)), c_block_size * sizeof(FloatC)); } @@ -610,8 +620,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, void* __restrict__ p_shared, const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, @@ -673,8 +683,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight Sequence<1, K0PerBlock, MPerBlock, K1>, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatABAdjusted, + FloatA, + FloatAAdjusted, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), ABlockTransferSrcAccessOrder, @@ -703,8 +713,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight Sequence<1, K0PerBlock, NPerBlock, K1>, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatABAdjusted, + FloatB, + FloatBAdjusted, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), BBlockTransferSrcAccessOrder, @@ -733,11 +743,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // sanity check constexpr index_t KPack = - math::max(K1, MfmaSelector::selected_mfma.k_per_blk); + math::max(K1, + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( - static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); + static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size, + static_cast(p_shared) + a_block_space_size, b_k0_n_k1_block_desc.GetElementSpaceSize()); // gridwise GEMM pipeline diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp index b12bcee0f414c6e56b6550c084630b86ab153bb5..7c401a49572a6f72c0e3a4c9d46d4bc4515f73cd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -45,7 +45,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..94306a4c958b6183c299a04169b93301000080d7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -0,0 +1,962 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_lds.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + + __shared__ uint8_t p_shared[shared_size]; + + GridwiseGemm::template Run( + karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); +#else + ignore = karg; + ignore = b2c_map; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_xdlops_splitk_lds_direct_load +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + static constexpr auto M01 = 1; + static constexpr auto N01 = 1; + + static constexpr auto gemm_padder = + tensor_operation::device::GemmPadder{ + MPerBlock, NPerBlock, K1* K0PerBlock}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + struct Argument : public ck::tensor_operation::device::BaseArgument + { + const FloatA* p_a_grid; + const FloatB* p_b_grid; + FloatC* p_c_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t K0Padded; + index_t k_batch; + + Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0Padded_, + index_t k_batch_) + : p_a_grid(p_a_grid_), + p_b_grid(p_b_grid_), + p_c_grid(p_c_grid_), + M(M_), + N(N_), + K(K_), + StrideA(StrideA_), + StrideB(StrideB_), + StrideC(StrideC_), + MPadded(MPadded_), + NPadded(NPadded_), + KPadded(KPadded_), + K0Padded(K0Padded_), + k_batch(k_batch_) + { + } + + void Print() const + { + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "K0Padded:" << K0Padded << ", " + << "KB:" << k_batch << "}" << std::endl; + } + }; + + __host__ __device__ static auto CalculateGridSize(const Argument& karg) + { + return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock), + math::integer_divide_ceil(karg.M, MPerBlock), + karg.k_batch); + } + + // prefer this to be called on host + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1) + { + // k_batch * k0 * k0_per_block * k1 + auto K_t = K_Batch * K0PerBlock * K1; + return (K + K_t - 1) / K_t * K0PerBlock; + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K0Padded = CalculateK0Padded(K, K_Batch); + return K_Batch * K0Padded * K1; + } + + __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, + index_t MPad, + index_t K, + index_t StrideA, + index_t KBatch, + index_t K0Padded, + index_t KPad) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, + index_t NPad, + index_t N, + index_t StrideB, + index_t KBatch, + index_t K0Padded, + index_t KPad) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + + return math::max(NumGemmKPrefetchStage * (a_block_space_size + b_block_space_size) * + sizeof(ComputeType), + c_block_size * sizeof(FloatC)); + } + + __host__ __device__ static constexpr bool CheckValidity(const Argument& karg) + { + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.k_batch * K0PerBlock * K1; + if(!(karg.K % K_t == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { + return false; + } + } + else + { + if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { + return false; + } + } + + const auto num_k_loop = karg.K0Padded / K0PerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + return true; + } + + __host__ __device__ static auto GetKPad(index_t K, index_t KBatch) + { + const index_t K0Padded = + math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0Padded * K1; + return KPad; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded) + { + const index_t num_loop = K0Padded / K0PerBlock; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + // return block_id to C matrix tile idx (m0, n0, k_split) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap() + { + return BlockToCTileMap_3DGrid_KSplit(); + } + + using CGridDesc_M_N = remove_cvref_t; + using DefaultBlock2CTileMap = remove_cvref_t; + + template + __device__ static void Run(const Argument& karg, + void* __restrict__ p_shared_block, + const Block2CTileMap& block_2_ctile_map, + const AElementwiseOperation a_element_op = AElementwiseOperation{}, + const BElementwiseOperation b_element_op = BElementwiseOperation{}, + const CElementwiseOperation c_element_op = CElementwiseOperation{}) + { + // Elementwise operations are not supported for A and B, arguments left only for the API + // consistency. + (void)a_element_op; + (void)b_element_op; + + const FloatA* p_a_grid = karg.p_a_grid; + const FloatB* p_b_grid = karg.p_b_grid; + FloatC* p_c_grid = karg.p_c_grid; + const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded); + const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1( + karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [KBatch, M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]); + const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto a_b_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto b_b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_K0_M_K1, + FloatA, + ComputeType, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_K0_N_K1, + FloatB, + ComputeType, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeType, // ComputeType A + ComputeType, // ComputeType B + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + K1, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + const auto a_buffers_offset = 0; + auto a_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared_block, + a_b_k0_m_k1_block_desc.GetElementSpaceSize(), + a_buffers_offset, + max_lds_align); + const auto b_buffers_offset = a_block_space_size * NumGemmKPrefetchStage; + auto b_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared_block, + b_b_k0_n_k1_block_desc.GetElementSpaceSize(), + b_buffers_offset, + max_lds_align); + + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) / + (K0PerBlock * K1)); + + const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + gridwise_gemm_pipeline.template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buffers, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buffers, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared_block), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp index 4408b34870e93b2e681160fee12e0743269ff6f7..e9190dee292064fb1b5b31d249c88527718768ea 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp @@ -37,7 +37,8 @@ __global__ void index_t StrideC, typename GridwiseGemm::Block2CTileMap block_mapping) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -489,6 +490,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1(p_a_grid, @@ -70,7 +70,7 @@ __global__ void kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const auto a_grid_desc_k0_m_k1 = @@ -175,7 +175,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; } - __host__ static auto CalculateK0(index_t K) { return math::integer_divide_floor(K, K1Value); } + __host__ static auto CalculateK0(index_t K) { return math::integer_divide_ceil(K, K1Value); } // Argument struct Problem @@ -194,7 +194,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 StrideC{StrideC_}, MPadded{CalculateMPadded(M_)}, NPadded{CalculateNPadded(N_)}, - K0{CalculateK0(K)} + K0{CalculateK0(K_)} { } @@ -369,9 +369,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 "Invalid tuning param!"); // check gridwise gemm pipeline - const index_t K0 = problem.K / K1Value; - const auto num_k_loop = K0 / K0PerBlock; - + const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock); if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; @@ -383,7 +381,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const index_t num_loop = K / (K0PerBlock * K1); + const index_t num_loop = math::integer_divide_ceil(K, K0PerBlock * K1); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } @@ -426,6 +424,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) { return transform_tensor_descriptor( a_grid_desc_m_k, @@ -874,7 +892,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext } }(); - if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock; + const auto KPad = K0Pad * K1Value; + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) { return transform_tensor_descriptor( b_grid_desc_k_n, @@ -908,7 +945,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext } }(); - if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) { return transform_tensor_descriptor(c_grid_desc_m_n, make_tuple(make_right_pad_transform(M, MPad - M), @@ -958,6 +996,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext } } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.K0 % K0PerBlock == 0)) + { + return false; + } + } + if constexpr(is_same::value) { if(problem.K % ABlockTransferSrcScalarPerVector != 0) @@ -989,8 +1038,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext } // check gridwise gemm pipeline - const index_t K0 = problem.K / K1; - const auto num_k_loop = K0 / K0PerBlock; + const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock); if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index 19fbee727f77ae8b6b6f90a5a6f520accde95332..7d8e94c00166cb3933cdb9bef00dc7a352b7a8b8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -43,7 +43,7 @@ __global__ void const CBlockClusterAdaptor c_block_cluster_adaptor) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 371281dfe20dddacf861677a771c3a425028a27f..b52f5c51b1ad267f0b808e283a3fa40ba5de4532 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -9,7 +9,6 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" @@ -22,25 +21,34 @@ namespace ck { template + typename Block2CTileMap, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, - const Block2CTileMap& b2c_map) + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; GridwiseGemm::template Run( - karg, static_cast(p_shared), b2c_map); + karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); #else ignore = karg; ignore = b2c_map; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -87,7 +95,10 @@ template + typename ComputeTypeA = FloatC, + typename ComputeTypeB = ComputeTypeA, + typename LDSTypeA = ComputeTypeA, + typename LDSTypeB = ComputeTypeB> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { static constexpr auto I0 = Number<0>{}; @@ -127,7 +138,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 index_t MPadded; index_t NPadded; index_t KPadded; - index_t K0; + index_t K0Padded; index_t k_batch; Argument(const FloatA* p_a_grid_, @@ -142,7 +153,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 index_t MPadded_, index_t NPadded_, index_t KPadded_, - index_t K0_, + index_t K0Padded_, index_t k_batch_) : p_a_grid(p_a_grid_), p_b_grid(p_b_grid_), @@ -156,7 +167,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 MPadded(MPadded_), NPadded(NPadded_), KPadded(KPadded_), - K0(K0_), + K0Padded(K0Padded_), k_batch(k_batch_) { } @@ -173,7 +184,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KP:" << KPadded << ", " - << "K0:" << K0 << ", " + << "K0Padded:" << K0Padded << ", " << "KB:" << k_batch << "}" << std::endl; } }; @@ -196,7 +207,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 return math::integer_least_multiple(N, NPerBlock); } - __host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1) + __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1) { // k_batch * k0 * k0_per_block * k1 auto K_t = K_Batch * K0PerBlock * K1; @@ -205,8 +216,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) { - auto K0 = CalculateK0(K, K_Batch); - return K_Batch * K0 * K1; + auto K0Padded = CalculateK0Padded(K, K_Batch); + return K_Batch * K0Padded * K1; } __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, @@ -214,7 +225,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 index_t K, index_t StrideA, index_t KBatch, - index_t K0, + index_t K0Padded, index_t KPad) { const auto a_grid_desc_m_k = [&]() { @@ -228,30 +239,57 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 } }(); - const auto a_grid_desc_m_kpad = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) { + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; return transform_tensor_descriptor( a_grid_desc_m_kpad, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), make_right_pad_transform(M, MPad - M)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); } - else + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding) + { + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return transform_tensor_descriptor( a_grid_desc_m_kpad, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), make_pass_through_transform(M)), make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -263,7 +301,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 index_t N, index_t StrideB, index_t KBatch, - index_t K0, + index_t K0Padded, index_t KPad) { const auto b_grid_desc_k_n = [&]() { @@ -277,30 +315,57 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 } }(); - const auto b_grid_desc_kpad_n = transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) { + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; return transform_tensor_descriptor( b_grid_desc_kpad_n, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), make_right_pad_transform(N, NPad - N)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); } - else + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding) { + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return transform_tensor_descriptor( b_grid_desc_kpad_n, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), make_pass_through_transform(N)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -367,7 +432,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 constexpr auto c_block_size = GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); - return math::max((a_block_space_size + b_block_space_size) * sizeof(ComputeType), + return math::max(a_block_space_size * sizeof(LDSTypeA) + + b_block_space_size * sizeof(LDSTypeB), c_block_size * sizeof(FloatC)); } @@ -389,6 +455,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 return false; } } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || @@ -401,6 +468,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.k_batch * K0PerBlock * K1; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + #endif // DEBUG_LOG return false; } @@ -469,11 +555,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) { #if DEBUG_LOG - std::cout - << "Arg N (" << karg.N - << ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL (" - << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CBlockTransferScalarPerVector_NWaveNPerXDL (" + << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; #endif // DEBUG_LOG return false; @@ -484,25 +570,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) { #if DEBUG_LOG - std::cout - << "Arg M (" << karg.M - << ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL (" - << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CBlockTransferScalarPerVector_NWaveNPerXDL (" + << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; #endif // DEBUG_LOG return false; } } - const auto num_k_loop = karg.K0 / K0PerBlock; + const auto num_k_loop = karg.K0Padded / K0PerBlock; if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { #if DEBUG_LOG std::cout << "The number of k loops (" << num_k_loop << ") value is not supported by GridwiseGemm Pipeline." - << " K0: " << karg.K0 << ", K0PerBlock: " << K0PerBlock << " " << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + << " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; #endif // DEBUG_LOG return false; } @@ -512,14 +598,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 __host__ __device__ static auto GetKPad(index_t K, index_t KBatch) { - const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; - const index_t KPad = KBatch * K0 * K1; + const index_t K0Padded = + math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0Padded * K1; return KPad; } - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded) { - const index_t num_loop = K0 / K0PerBlock; + const index_t num_loop = K0Padded / K0PerBlock; return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } @@ -577,22 +664,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 typename Block2CTileMap> __device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block, - const Block2CTileMap& block_2_ctile_map) + const Block2CTileMap& block_2_ctile_map, + const AElementwiseOperation a_element_op = AElementwiseOperation{}, + const BElementwiseOperation b_element_op = BElementwiseOperation{}, + const CElementwiseOperation c_element_op = CElementwiseOperation{}) { const FloatA* p_a_grid = karg.p_a_grid; const FloatB* p_b_grid = karg.p_b_grid; FloatC* p_c_grid = karg.p_c_grid; const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( - karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded); + karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded); const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1( - karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded); + karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); - const AElementwiseOperation a_element_op = AElementwiseOperation{}; - const BElementwiseOperation b_element_op = BElementwiseOperation{}; - const CElementwiseOperation c_element_op = CElementwiseOperation{}; const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); @@ -701,7 +788,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, FloatA, - ComputeType, + LDSTypeA, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), ABlockTransferSrcAccessOrder, @@ -731,7 +818,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, FloatB, - ComputeType, + LDSTypeB, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), BBlockTransferSrcAccessOrder, @@ -761,7 +848,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeType, + LDSTypeA, + LDSTypeB, FloatAcc, decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc), @@ -770,7 +858,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 MRepeat, NRepeat, K1, - LoopSched>(); + LoopSched, + ComputeTypeA, + ComputeTypeB>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -778,8 +868,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 constexpr auto a_block_space_size = math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - ComputeType* p_a_block = static_cast(p_shared_block); - ComputeType* p_b_block = static_cast(p_shared_block) + a_block_space_size; + auto p_a_block = reinterpret_cast(p_shared_block); + auto p_b_block = reinterpret_cast(p_a_block + a_block_space_size); constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 8d7bd9a8d1eeae302e97fd465ba1678eac1b8bba..15c64f2e474272ac05f5a901e9bb8899103875d7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -47,7 +47,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -451,6 +451,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( @@ -471,6 +471,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( @@ -489,6 +489,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_tensor_rearrange(const InputGridDesc in_grid_desc, + const InputDataType* __restrict__ p_in_global, + const OutputGridDesc out_grid_desc, + OutputDataType* __restrict__ p_out_global, + const index_t batch_count, + const Block2ETileMap block_2_tile_map, + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) + GridwiseTensorRearrangeKernel::Run(in_grid_desc, + p_in_global, + out_grid_desc, + p_out_global, + batch_count, + block_2_tile_map, + compute_ptr_offset_of_batch); +#else + ignore = in_grid_desc; + ignore = p_in_global; + ignore = out_grid_desc; + ignore = p_out_global; + ignore = batch_count; + ignore = block_2_tile_map; + ignore = compute_ptr_offset_of_batch; +#endif +} + +template +struct GridwiseTensorRearrange +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + using ThisThreadBlock = ThisThreadBlock; + + __device__ static void Run(const InputGridDesc& in_grid_desc, + const InputDataType* __restrict__ p_in_global, + const OutputGridDesc& out_grid_desc, + OutputDataType* __restrict__ p_out_global, + const index_t batch_count, + const Block2ETileMap& block_2_tile_map, + const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch) + { + const auto block_work_idx = + block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t k_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock); + + auto copy_global_to_global = + ThreadGroupTensorSliceTransfer_v7, + Tuple, + decltype(tie(in_grid_desc)), + decltype(tie(out_grid_desc)), + tensor_operation::element_wise::PassThrough, + Sequence(DstInMemOp)>, + Sequence, + ThreadClusterLengths, + Sequence<0, 1>, + Sequence<0, 1>, + I1, + ScalarPerVector, + Sequence, + Sequence>{ + in_grid_desc, + make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)), + out_grid_desc, + make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)), + tensor_operation::element_wise::PassThrough{}}; + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + // Global Memory + const index_t a_batch_offset = + __builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const index_t c_batch_offset = + __builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + const auto in_global_buf = make_dynamic_buffer( + p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize()); + auto out_global_buf = make_dynamic_buffer( + p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize()); + + copy_global_to_global.Run( + tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf)); + } + + __host__ static constexpr bool CheckValidity(const InputGridDesc& in_grid_desc, + const OutputGridDesc& out_grid_desc) + { + if(in_grid_desc.GetLength(I0) % MPerBlock != 0 || + in_grid_desc.GetLength(I1) % KPerBlock != 0) + return false; + if(out_grid_desc.GetLength(I0) % MPerBlock != 0 || + out_grid_desc.GetLength(I1) % KPerBlock != 0) + return false; + return true; + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8a0e16d7f6e03cda523fc298070443154d97e992 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp @@ -0,0 +1,554 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" + +namespace ck { + +// Tensor Shape +// dy, x = [M, K], gamma = [1, K], x_mean, inv_std = [M, 1] + +// Flow: +// def normalization_backward_x(dy, x, gamma, x_mean, inv_std, reduce_axis, reduce_size): +// ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True) +// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True) +// b = (db * x_mean - ds) * inv_std ** (3) / reduce_size +// c = -b * x_mean - db * inv_std / reduce_size +// dx = inv_std * dy * gamma + b * x + c +// return dx + +template +struct GridwiseNormalizationBwdData_mk_to_mk +{ + // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce) + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) || + (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) || + (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize == GammaSrcVectorSize) || + (GammaSrcVectorDim == 1 && KThreadSliceSize == GammaSrcVectorSize)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize == MeanInvStdSrcVectorSize) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize == MeanInvStdSrcVectorSize)), + "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!"); + + static_assert(((DXDstVectorDim == 0 && MThreadSliceSize == DXDstVectorSize) || + (DXDstVectorDim == 1 && KThreadSliceSize == DXDstVectorSize)), + "Invalid thread slice sizes and/or dx vector sizes configuration, please check!"); + + using ThreadClusterLengths_M_K = Sequence; + + using DYThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using XThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using GammaThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using MeanInvStdThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using DXThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder; + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + using BlockwiseSumReduce = PartitionedBlockwiseReduction; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k, + const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& gamma_grid_desc_m_k, + const GridDesc_M_K& mean_grid_desc_m_k, + const GridDesc_M_K& inv_std_grid_desc_m_k, + const GridDesc_M_K& dx_grid_desc_m_k, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DXDataType* const __restrict__ p_dx_global) + { + // LDS + __shared__ ComputeDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + // Global + const auto dy_global_val_buf = make_dynamic_buffer( + p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize()); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize()); + + const auto inv_std_global_val_buf = make_dynamic_buffer( + p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize()); + + auto dx_global_val_buf = make_dynamic_buffer( + p_dx_global, dx_grid_desc_m_k.GetElementSpaceSize()); + + // VGPR + auto dy_thread_buf = StaticBuffer{}; + + auto x_thread_buf = StaticBuffer{}; + + auto gamma_thread_buf = StaticBuffer{}; + + auto mean_thread_buf = StaticBuffer{}; + + auto inv_std_thread_buf = StaticBuffer{}; + + auto dx_thread_buf = StaticBuffer{}; + + auto ds_thread_buf = + StaticBuffer{}; + + auto db_thread_buf = + StaticBuffer{}; + + // thread id + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + // IO + auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_mean_load = + ThreadwiseTensorSliceTransfer_v2( + mean_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_inv_std_load = + ThreadwiseTensorSliceTransfer_v2( + inv_std_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dx_store = + ThreadwiseTensorSliceTransfer_v1r3( + dx_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize), + PassThroughOp{}); + + ComputeDataType reduce_size = type_convert( + dy_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + ds_thread_buf(I) = type_convert(0.0f); + db_thread_buf(I) = type_convert(0.0f); + }); + + // Separate sweep once and sweep twice pipeline + // Sweep once: for small k, if KThreadClusterSize * KThreadSliceSize > K + // we don't need to use loop to read x, dy, gamma twice + if constexpr(SweepOnce) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + threadwise_mean_load.Run(mean_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + mean_thread_buf); + + threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, + inv_std_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + inv_std_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + x_thread_buf[offset_m_k]; + + db_thread_buf(offset_m) += + dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k]; + }); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I)); + block_sync_lds(); + BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I)); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + // b = (db * x_mean - ds) * rstd ** (3) / reduce_size + // c = -b * x_mean - db * rstd / reduce_size + // dx = rstd * dy * gamma + b * x + c + + ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] - + ds_thread_buf[offset_m]; + + b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] / reduce_size; + + ComputeDataType c = -b * mean_thread_buf(offset_m_k); + + c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size; + + dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] + + b * x_thread_buf[offset_m_k] + c; + }); + }); + + threadwise_dx_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + dx_thread_buf, + dx_grid_desc_m_k, + dx_global_val_buf); + + } // end of sweep once + else // Sweep Twice pipeline + { + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + x_thread_buf[offset_m_k]; + + db_thread_buf(offset_m) += + dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k]; + }); + }); + } // end of first sweep + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I)); + block_sync_lds(); + BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I)); + }); + + // reverse read for using dy, gamma and x in the cache + constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize); + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; + + // move to tail + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k); + + // move from start to tail + threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + threadwise_mean_load.Run(mean_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + mean_thread_buf); + + threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, + inv_std_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + inv_std_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + // b = (db * x_mean - ds) * rstd ** (3) / reduce_size + // c = -b * x_mean - db * rstd / reduce_size + // dx = rstd * dy * gamma + b * x + c + + ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] - + ds_thread_buf[offset_m]; + + b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] / reduce_size; + + ComputeDataType c = -b * mean_thread_buf(offset_m_k); + + c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size; + + dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] + + b * x_thread_buf[offset_m_k] + c; + }); + }); + + threadwise_dx_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + dx_thread_buf, + dx_grid_desc_m_k, + dx_global_val_buf); + + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k); + } + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp new file mode 100644 index 0000000000000000000000000000000000000000..21248e3a0a805d81756fd3b2ab4dfe4e73a85915 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" + +namespace ck { + +// dgamma = reduce_sum(dy * (x - mean) * inv_std) +// dbeta = reduce_sum(dy) +template +struct GridwiseNormalizationBwdGammaBeta_mk_to_k +{ + // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce) + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) || + (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) || + (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + // do not force SliceSize == MeanInvStdSrcVectorSize for groupnorm + static_assert( + ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0)), + "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!"); + + static_assert(MThreadSliceSize == DGammaDstVectorSize && MThreadSliceSize == DBetaDstVectorSize, + "Invalid thread slice sizes and/or dx vector sizes configuration, please check!"); + + using ThreadClusterLengths_M_K = Sequence; + + using DYThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using XThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using MeanInvStdThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + using BlockwiseSumReduce = PartitionedBlockwiseReduction; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k, + const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& mean_grid_desc_m_k, + const GridDesc_M_K& inv_std_grid_desc_m_k, + const GridDesc_M& dgamma_grid_desc_m, + const GridDesc_M& dbeta_grid_desc_m, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DGammaDataType* const __restrict__ p_dgamma_global, + DBetaDataType* const __restrict__ p_dbeta_global) + { + // LDS + __shared__ ComputeDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + // Global + const auto dy_global_val_buf = make_dynamic_buffer( + p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize()); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize()); + + const auto inv_std_global_val_buf = make_dynamic_buffer( + p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize()); + + auto dgamma_global_val_buf = make_dynamic_buffer( + p_dgamma_global, dgamma_grid_desc_m.GetElementSpaceSize()); + + auto dbeta_global_val_buf = make_dynamic_buffer( + p_dbeta_global, dbeta_grid_desc_m.GetElementSpaceSize()); + + // VGPR + auto dy_thread_buf = StaticBuffer{}; + + auto x_thread_buf = StaticBuffer{}; + + auto mean_thread_buf = StaticBuffer{}; + + auto inv_std_thread_buf = StaticBuffer{}; + + auto dgamma_thread_buf = + StaticBuffer{}; + + auto dbeta_thread_buf = + StaticBuffer{}; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + // IO + auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_mean_load = + ThreadwiseTensorSliceTransfer_v2( + mean_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_inv_std_load = + ThreadwiseTensorSliceTransfer_v2( + inv_std_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dgamma_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + DGammaDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + dgamma_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_dbeta_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + DBetaDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + dbeta_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + dgamma_thread_buf(I) = type_convert(0.0f); + dbeta_thread_buf(I) = type_convert(0.0f); + }); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_mean_load.Run(mean_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + mean_thread_buf); + + threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, + inv_std_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + inv_std_thread_buf); + + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, + thread_copy_fwd_step_m_k); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + dgamma_thread_buf(offset_m) += + dy_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] * + (x_thread_buf[offset_m_k] - mean_thread_buf[offset_m_k]); + + dbeta_thread_buf(offset_m) += dy_thread_buf[offset_m_k]; + }); + }); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, dbeta_thread_buf(I)); + block_sync_lds(); + BlockwiseSumReduce::Reduce(reduce_work_buf, dgamma_thread_buf(I)); + }); + + if(thread_k_cluster_id == 0) + { + threadwise_dgamma_store.Run(thread_buffer_desc_m, + make_tuple(I0), + dgamma_thread_buf, + dgamma_grid_desc_m, + dgamma_global_val_buf); + + threadwise_dbeta_store.Run(thread_buffer_desc_m, + make_tuple(I0), + dbeta_thread_buf, + dbeta_grid_desc_m, + dbeta_global_val_buf); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp index c3f122106df29a471e376c453d055806c9aa2514..1bee1b93f9763efd6d5a92fddf65fbef8c62cb5c 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp @@ -18,9 +18,11 @@ template struct GridwiseNormalizationNaiveVariance_mk_to_mk { @@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + static_assert(XSrcVectorSize == YDstVectorSize); static_assert(XSrcVectorSize == GammaSrcVectorSize); static_assert(XSrcVectorSize == BetaSrcVectorSize); @@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); + using ThreadBufferLengths_M = Sequence; + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}))); using ThreadReduceDstDesc_M = @@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk reduce::Add, true>; + using PassThroughOp = tensor_operation::element_wise::PassThrough; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk const GridDesc_M_K& gamma_grid_desc_m_k, const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k, + const GridDesc_M& save_mean_grid_desc_m, + const GridDesc_M& save_inv_std_grid_desc_m, index_t num_k_block_tile_iteration, ComputeDataType epsilon, const XDataType* const __restrict__ p_x_global, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op) { // LDS @@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk auto y_global_val_buf = make_dynamic_buffer( p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + auto save_mean_global_val_buf = make_dynamic_buffer( + p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize()); + + auto save_inv_std_global_val_buf = make_dynamic_buffer( + p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize()); + auto x_thread_buf = generate_tuple( [&](auto) { return StaticBuffer& var_thread_buf = mean_square_thread_buf; + StaticBuffer& + inv_std_thread_buf = mean_square_thread_buf; const index_t thread_local_id = get_thread_local_1d_id(); const index_t block_global_id = get_block_1d_id(); @@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk thread_k_cluster_id * YDstVectorSize), y_elementwise_op); + auto threadwise_mean_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_mean_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_inv_std_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_inv_std_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); @@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // E(x), E[x^2], var(x) // FIXME: Should not hack the transform from deviceOP - int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; + ComputeDataType reduce_length = type_convert( + x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]); static_for<0, MThreadSliceSize, 1>{}([&](auto I) { mean_thread_buf(I) = reduce::Add::template GetIdentityValue(); @@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // var(x) = E[x^2] - E[x]^2 var_thread_buf(I) = mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); + + inv_std_thread_buf(I) = type_convert(1.0f) / + ck::math::sqrt(var_thread_buf(I) + epsilon); }); + // save mean and inverse std for backward (optional) + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + // normalization static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -314,7 +400,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma & beta y_thread_buf(iK0)(Number{}) = @@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // var(x) = E[x^2] - E[x]^2 var_thread_buf(I) = mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); + + inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon); }); + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; @@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma y_thread_buf(iK0)(Number{}) = diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp index e50fb9813325b08576f49d321e7a5de5d51bbff2..8157b4fbc39cc1ce9a3958ce2036e094b0d7bdcd 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp @@ -12,31 +12,42 @@ template -__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, - const GridDesc_M_K gamma_grid_desc_m_k, - const GridDesc_M_K beta_grid_desc_m_k, - const GridDesc_M_K y_grid_desc_m_k, - index_t num_k_block_tile_iteration, - ComputeDataType epsilon, - const XDataType* const __restrict__ p_x_global, - const GammaDataType* const __restrict__ p_gamma_global, - const BetaDataType* const __restrict__ p_beta_global, - YDataType* const __restrict__ p_y_global, - const YElementwiseOperation y_elementwise_op) + typename GridDesc_M_K, + typename GridDesc_M> +__global__ void +kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K gamma_grid_desc_m_k, + const GridDesc_M_K beta_grid_desc_m_k, + const GridDesc_M_K y_grid_desc_m_k, + const GridDesc_M save_mean_grid_desc_m, + const GridDesc_M save_inv_std_grid_desc_m, + index_t num_k_block_tile_iteration, + ComputeDataType epsilon, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, + const YElementwiseOperation y_elementwise_op) { GridwiseReduction::Run(x_grid_desc_m_k, gamma_grid_desc_m_k, beta_grid_desc_m_k, y_grid_desc_m_k, + save_mean_grid_desc_m, + save_inv_std_grid_desc_m, num_k_block_tile_iteration, epsilon, p_x_global, p_gamma_global, p_beta_global, p_y_global, + p_save_mean_global, + p_save_inv_std_global, y_elementwise_op); }; @@ -44,9 +55,11 @@ template auto NormalizationKernelSelector(bool isSweepOnce) { @@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce) GammaDataType, BetaDataType, YDataType, + SaveMeanInvStdDataType, ComputeDataType, YElementwiseOperation, GridDesc_M_K, + GridDesc_M, BlockSize, MThreadClusterSize, KThreadClusterSize, @@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce) BetaSrcVectorSize, YDstVectorDim, YDstVectorSize, + SaveMeanInvStdDstVectorSize, false>; using GridwiseNormalizationSweepOnceNaive = GridwiseNormalizationNaiveVariance_mk_to_mk; using GridwiseNormalizationGenericWelford = GridwiseNormalizationWelfordVariance_mk_to_mk; using GridwiseNormalizationSweepOnceWelford = GridwiseNormalizationWelfordVariance_mk_to_mk; if constexpr(UseWelford) @@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce) GammaDataType, BetaDataType, YDataType, + SaveMeanInvStdDataType, ComputeDataType, YElementwiseOperation, - GridDesc_M_K> + GridDesc_M_K, + GridDesc_M> : kernel_normalization; + GridDesc_M_K, + GridDesc_M>; } else { @@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce) GammaDataType, BetaDataType, YDataType, + SaveMeanInvStdDataType, ComputeDataType, YElementwiseOperation, - GridDesc_M_K> + GridDesc_M_K, + GridDesc_M> : kernel_normalization; + GridDesc_M_K, + GridDesc_M>; } } diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp index 136ac94e7f0fabb81cfd5baea7e58174799f4d43..9e380f963801460e93684ba4a5cf3244c3fef1cb 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp @@ -17,11 +17,13 @@ template + index_t YDstVectorSize, + index_t SaveMeanInvStdDstVectorSize> struct GridwiseNormalizationSplitK2nd { static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || @@ -45,6 +48,10 @@ struct GridwiseNormalizationSplitK2nd (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + static_assert(XSrcVectorSize == YDstVectorSize); static_assert(XSrcVectorSize == GammaSrcVectorSize); static_assert(XSrcVectorSize == BetaSrcVectorSize); @@ -69,6 +76,10 @@ struct GridwiseNormalizationSplitK2nd static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); + using ThreadBufferLengths_M = Sequence; + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + using ThreadBufferLengths_M_1 = Sequence; static constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1)); @@ -99,6 +110,8 @@ struct GridwiseNormalizationSplitK2nd const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k, const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k, const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k, + const SaveMeanInvStdGridDesc_M& save_mean_grid_desc_m, + const SaveMeanInvStdGridDesc_M& save_inv_std_grid_desc_m, index_t num_k_mean_var_count_iteration, index_t num_k_block_tile_iteration, index_t k_grid_size, @@ -110,6 +123,8 @@ struct GridwiseNormalizationSplitK2nd const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op) { // Thread/Block id @@ -145,6 +160,12 @@ struct GridwiseNormalizationSplitK2nd auto y_global_val_buf = make_dynamic_buffer( p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + auto save_mean_global_val_buf = make_dynamic_buffer( + p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize()); + + auto save_inv_std_global_val_buf = make_dynamic_buffer( + p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize()); + // VGPR StaticBuffer in_mean_thread_buf; @@ -158,6 +179,7 @@ struct GridwiseNormalizationSplitK2nd var_thread_buf; StaticBuffer welford_count_thread_buf; + auto& inv_std_thread_buf = var_thread_buf; auto x_thread_buf = generate_tuple( [&](auto) { @@ -283,6 +305,42 @@ struct GridwiseNormalizationSplitK2nd thread_k_cluster_id * YDstVectorSize), y_elementwise_op); + auto threadwise_mean_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_mean_grid_desc_m, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_inv_std_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_inv_std_grid_desc_m, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + // step1: Merge mean and variance constexpr auto mean_var_count_thread_copy_step_I0_k = make_multi_index(I0, KThreadClusterSize); @@ -332,9 +390,33 @@ struct GridwiseNormalizationSplitK2nd BlockwiseWelford::Run( mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I)); + + inv_std_thread_buf(I) = + type_convert(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon); }); - // step2: normalization + // step2: save mean and inverse std for backward (optional) + if(block_k_cluster_id == 0 && thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + // step3: normalization constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); for(index_t k = 0; k < num_k_block_tile_iteration; ++k) @@ -360,7 +442,6 @@ struct GridwiseNormalizationSplitK2nd }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -369,7 +450,7 @@ struct GridwiseNormalizationSplitK2nd // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma y_thread_buf(iK0)(Number{}) = diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp index ff9712276c26d14c9a68f9df618ed4666f2dc5e9..15b412fba4cc0dd3d33e6b9a4f663c94a2053bda 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp @@ -16,9 +16,11 @@ template struct GridwiseNormalizationWelfordVariance_mk_to_mk { @@ -43,6 +46,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + static_assert(XSrcVectorSize == YDstVectorSize); static_assert(XSrcVectorSize == GammaSrcVectorSize); static_assert(XSrcVectorSize == BetaSrcVectorSize); @@ -64,6 +71,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); + using ThreadBufferLengths_M = Sequence; + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}))); using ThreadReduceDstDesc_M = @@ -77,6 +88,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ThreadClusterLengths_M_K, ThreadClusterArrangeOrder>; + using PassThroughOp = tensor_operation::element_wise::PassThrough; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -114,17 +127,18 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk const GridDesc_M_K& gamma_grid_desc_m_k, const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k, + const GridDesc_M& save_mean_grid_desc_m, + const GridDesc_M& save_inv_std_grid_desc_m, index_t num_k_block_tile_iteration, ComputeDataType epsilon, const XDataType* const __restrict__ p_x_global, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op) { - auto y_global_val_buf = make_dynamic_buffer( - p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); - auto x_thread_buf = generate_tuple( [&](auto) { return StaticBuffer var_thread_buf; + auto& inv_std_thread_buf = var_thread_buf; const index_t thread_local_id = get_thread_local_1d_id(); const index_t block_global_id = get_block_1d_id(); @@ -226,6 +241,42 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk thread_k_cluster_id * YDstVectorSize), y_elementwise_op); + auto threadwise_mean_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_mean_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_inv_std_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_inv_std_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); @@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk const auto beta_global_val_buf = make_dynamic_buffer( p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + auto save_mean_global_val_buf = make_dynamic_buffer( + p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize()); + + auto save_inv_std_global_val_buf = make_dynamic_buffer( + p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize()); + auto threadwise_welford = ThreadwiseWelford(); threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id); @@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk int count = threadwise_welford.cur_count_; BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + inv_std_thread_buf(I) = type_convert(1.0f) / + ck::math::sqrt(var_thread_buf(I) + epsilon); }); + // save mean and inverse std for backward (optional) + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + // normalization static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma & beta y_thread_buf(iK0)(Number{}) = @@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk int count = threadwise_welford.cur_count_; BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon); }); + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; @@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma y_thread_buf(iK0)(Number{}) = diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp deleted file mode 100644 index d0d214381d14e64f91805ca2d605d78f66470c3f..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp +++ /dev/null @@ -1,136 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/amd_gemm_dpp.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/utility/inner_product_dpp8.hpp" -#include "ck/utility/math.hpp" - -namespace ck { - -/** - * Threadwise contraction using dot instructions with DPP8 modifier. - * - * Assumptions: - * 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1` - * are known at compile-time; - * 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time; - * 3. `TM0` is equal to 1 and `TN0` is equal to 1; - * 4. When `ShareA` is set (unset, respectively), `TM1` (`TN1`, respectively) is divisible by - * the size of the lane group (`dpp8::lane_group_size`). - */ -template ::type = false> -struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 -{ - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - static constexpr index_t TK0 = TKLengths{}[I0]; - static constexpr index_t TK1 = TKLengths{}[I1]; - static constexpr index_t TM0 = TMLengths{}[I0]; - static constexpr index_t TM1 = TMLengths{}[I1]; - static constexpr index_t TN0 = TNLengths{}[I0]; - static constexpr index_t TN1 = TNLengths{}[I1]; - - static_assert(TM0 == 1 && TN0 == 1); - - static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) || - (!ShareA && TN1 % dpp8::lane_group_size == 0)); - static constexpr index_t shared_elems_per_lane = - ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size; - - __device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() - { - static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && - BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && - CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2, - "wrong!"); - } - - template - __device__ static void Run(const ABuffer& a_buf, - AOriginIdx, - const BBuffer& b_buf, - BOriginIdx, - CBuffer& c_buf, - COriginIdx) - { - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - - static_assert( - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value && - "wrong! inconsistent type"); - - constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); - constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); - constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); - - static_for<0, TK0, 1>{}([&](auto tk0) { - static_for<0, TM1, 1>{}([&](auto tm1) { - static_for<0, TN1, 1>{}([&](auto tn1) { - vector_type a_vec; - vector_type b_vec; - - static_for<0, TK1, 1>{}([&](auto tk1) { - constexpr index_t local_tm1 = ShareA ? tm1 % shared_elems_per_lane : tm1; - constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( - a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1)); - - constexpr index_t local_tn1 = ShareA ? tn1 : tn1 % shared_elems_per_lane; - constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( - b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1)); - - a_vec.template AsType()(tk1) = a_buf[Number{}]; - b_vec.template AsType()(tk1) = b_buf[Number{}]; - }); - - using a_vector_t = typename vector_type::type; - using b_vector_t = typename vector_type::type; - - constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( - c_origin_idx + make_multi_index(0, tm1, 0, tn1)); - - constexpr int src_lane = - ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size - : (tn1 / shared_elems_per_lane) % dpp8::lane_group_size; - - dpp8::inner_product_dpp( - a_vec.template AsType()[I0], - b_vec.template AsType()[I0], - c_buf(Number{})); - }); - }); - }); - } -}; - -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 605f2569c6622bb589d46da71020ea46ddba69f7..608679a4faf0b328d70c311bcad1e21625561b2d 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -8,6 +8,8 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + namespace ck { // Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory @@ -137,13 +139,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3 constexpr index_t src_offset = src_desc.CalculateOffset( src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_buf[Number{}]); - // apply type convert - dst_vector.template AsType()(i) = type_convert(v); + dst_vector.template AsType()(i) = v; }); const bool is_dst_valid = @@ -1157,27 +1158,56 @@ struct ThreadwiseTensorSliceTransfer_v4 src_ref_to_origin_disp_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); - // apply type convert src_tmp_vector.template AsType()(i) = src_buf[Number{}]; }); } - // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to - // DstData) - vector_type_maker_t dst_tmp_vector; - // TODO: if SrcData and DstData are vetor type, then static_cast may not compile - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - dst_tmp_vector.template AsType()(i) = - type_convert(src_tmp_vector.template AsType()[i]); - }); + if constexpr(is_same, f8_t>::value && + is_same, half_t>::value && + SrcScalarPerVector % 2 == 0) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + constexpr index_t pack_size = 2; + + using dst_v_t = typename vector_type_maker_t::type; + using src_v_t = typename vector_type_maker_t::type; + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::PassThroughPack2{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i]); + }); - // copy data from dst_tmp_vector into dst_buf - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); - dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; - }); + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + // TODO: if SrcData and DstData are vetor type, then static_cast may not compile + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + dst_tmp_vector.template AsType()(i) = + type_convert(src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } }); } @@ -1289,13 +1319,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic constexpr index_t dst_offset = dst_desc.CalculateOffset( dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_buf[Number{}]); // apply type convert - dst_buf(Number{}) = type_convert(v); + dst_buf(Number{}) = v; }); }); } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 32ea8ae39c60c0e2130ab9381d94ecd8ea399982..0b2300fbe5fc3a2e15d8a66b81c98c731e45b7d2 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -9,6 +9,7 @@ #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor/static_tensor.hpp" +#include "ck/utility/is_detected.hpp" namespace ck { @@ -211,10 +212,44 @@ struct ThreadwiseTensorSliceTransfer_v3r1 auto src_vector_container = src_vector_type{ src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + dst_vector_type op_r_v; + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + using src_elem_op_vec_t = typename vector_type::type; + using dst_elem_op_vec_t = typename vector_type::type; + + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) { + // apply the src elementwise op and convert to DstData under the hood if needed + src_element_op_(op_r_v.template AsType()(idx), + src_vector_container.template AsType()[idx]); + }); + // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) - .template SetAsType( - src_data_idx_seq, src_vector_container.template AsType()[I0]); + .template SetAsType(src_data_idx_seq, + op_r_v.template AsType()[I0]); constexpr auto move_on_dim = [&]() constexpr { @@ -267,19 +302,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1 { #if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE static_ford{}([&](auto idx) { - // convert from SrcData to DstData here - dst_thread_scratch_(idx) = - type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); + dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); #else // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ // TODO make this logic more generic for more sub-dword datatype if constexpr(SrcVectorDim != DstVectorDim && - ((is_same>::value && - is_same>::value && + ((is_same>::value && SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || - (is_same>::value && - is_same>::value && + (is_same>::value && SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) { // each transpose does @@ -313,7 +344,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); - using src_vector_t = vector_type_maker_t; + using src_vector_t = vector_type_maker_t; using dst_vector_t = vector_type_maker_t; // get DstScalarPerVector # of read-only references to src vectors from @@ -336,17 +367,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1 Number{}); // do data transpose - transpose_vectors{}( + transpose_vectors{}( src_vector_refs, dst_vector_refs); }); } - - static_ford{}([&](auto idx) { - // apply the src elementwise op and convert to DstData under the hood if needed - DstData dst_v; - src_element_op_(dst_v, src_thread_scratch_tuple_[thread_scratch_id][idx]); - dst_thread_scratch_(idx) = dst_v; - }); + else + { + static_ford{}([&](auto idx) { + dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; + }); + } #endif } @@ -761,11 +791,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; - using SrcThreadScratch = StaticTensorTupleOfVectorBuffer; + using SrcThreadScratch = + StaticTensorTupleOfVectorBuffer; using DstThreadScratch = StaticTensorTupleOfVectorBuffer{}([&](auto i) { - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_vector_container.template AsType()[i]); // apply type convert - dst_vector_container.template AsType()(i) = type_convert(v); + dst_vector_container.template AsType()(i) = v; }); const bool is_dst_valid = diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..299a4b9e7ded4b737227f8f111185cfb321a9df8 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp @@ -0,0 +1,386 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + index_t SrcScalarPerVector, + index_t DstScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags> // Sequence +struct ThreadwiseTensorSliceTransfer_v7r2 +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SrcSpaceFillingCurve = SpaceFillingCurve>; + + static constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using DstSpaceFillingCurve = SpaceFillingCurve>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! cannot evenly divide"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + }); + } + + template + __device__ static auto generate_vectors() + { + auto data_types = DataTypes{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + template = false> + __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs) + { + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(); + auto dst_vectors = generate_vectors(); + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), + is_src_valid); + }); + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + // apply pointwise function + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return dst_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + dst_vectors_tuple_(iAccess) = dst_vectors; + + // move coordinate + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + } + + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs) + { + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + auto dst_vectors = dst_vectors_tuple_[iAccess]; + + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + dst_bufs(i).template Update( + dst_coords_[i].GetOffset(), + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + if constexpr(num_access == 0) + { + return typename SrcSpaceFillingCurve::Index{}; + } + else + { + return SrcSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + if constexpr(num_access == 0) + { + return typename DstSpaceFillingCurve::Index{}; + } + else + { + return DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + using SrcVectorsType = decltype(generate_vectors()); + using DstVectorsType = decltype(generate_vectors()); + + static constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess(); + + StaticallyIndexedArray dst_vectors_tuple_; + + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a18443164822cc9ed5264eb142aa773866af0597 --- /dev/null +++ b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp @@ -0,0 +1,538 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_gemm_dpp.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +enum struct DppInstr +{ + dpp8_f16_1x32x2 = 0, + dpp8_f16_2x16x2, + dpp8_f16_2x32x2, + dpp8_f16_4x16x2, + dpp8_f16_4x32x2, + dpp8_f16_8x16x2, + dpp8_f16_8x32x2, + dpp8_f16_16x16x2, + dpp8_f16_32x8x2 +}; + +/** + * Structure representing DPP GEMM executed by a single wavefront. + * + * Each structure instantiation must contain the following fields: + * - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the + * number of threads in a wavefront; + * - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier, + * it's 8 in case of DPP8; + * - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM + * operation; + * - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM + * operation; + * - m_per_lanegroup - size along M dimension that is processed by a single lanegroup; + * - n_per_lanegroup - size along N dimension that is processed by a single lanegroup; + * - m_per_thread - size along M dimension of the tile calculated by a single thread; + * - n_per_thread - size along N dimension of the tile calculated by a single thread; + * - k_per_dpp - size along K dimension that is reduced in a single DPP GEMM operation; + * - share_a - indicates whether we share matrix A or matrix B between lanes using DPP modifiers. + * + * Not all the combinarions are supported now, for current restrictions see the static asserts + * in the DppSelector's contructor. + */ +template +struct dpp_type; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 32; + static constexpr index_t n_per_wave = 8; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 8; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 8; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 4; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 4; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 16; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 4; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 4; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 4; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 4; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 2; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 2; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 1; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 1; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 1; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 2; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 2; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 2; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 2; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 1; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 1; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template +struct DppSelector +{ + template + static constexpr auto GetDpp(); + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_8x32x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_8x16x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_16x16x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_32x8x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_1x32x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_2x32x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_2x16x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_4x16x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_4x32x2; + } + + static constexpr auto selected_dpp = dpp_type()>{}; + + __host__ __device__ constexpr DppSelector() + { + static_assert(selected_dpp.m_per_wave % selected_dpp.m_per_lanegroup == 0); + static_assert(selected_dpp.n_per_wave % selected_dpp.n_per_lanegroup == 0); + + static_assert(selected_dpp.k_per_dpp % 2 == 0); + + static_assert(selected_dpp.wave_size % selected_dpp.lanegroup_size == 0); + constexpr index_t num_dpp_per_wave = selected_dpp.wave_size / selected_dpp.lanegroup_size; + constexpr index_t num_wave_c_elems = selected_dpp.m_per_wave * selected_dpp.n_per_wave; + constexpr index_t num_dpp_c_elems = + selected_dpp.m_per_lanegroup * selected_dpp.n_per_lanegroup; + static_assert(num_wave_c_elems % num_dpp_c_elems == 0); + static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems); + + if constexpr(selected_dpp.share_a) + { + static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); + static_assert(selected_dpp.n_per_lanegroup % selected_dpp.n_per_thread == 0); + static_assert(selected_dpp.n_per_lanegroup / selected_dpp.n_per_thread == + selected_dpp.lanegroup_size); + } + else + { + static_assert(selected_dpp.m_per_lanegroup % selected_dpp.n_per_thread == 0); + static_assert(selected_dpp.m_per_lanegroup / selected_dpp.n_per_thread == + selected_dpp.lanegroup_size); + static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread); + } + + // Below checks come from the restrictions of the current implementation, could be removed + // in the future when the implementation is more generalized. + static_assert(selected_dpp.share_a); + static_assert(selected_dpp.n_per_thread == 1); + static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); + static_assert(selected_dpp.n_per_lanegroup == + selected_dpp.n_per_thread * selected_dpp.lanegroup_size); + } + + static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; } +}; + +template +struct DppGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __host__ __device__ constexpr DppGemm() + { + static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp."); + } + + __device__ static constexpr index_t GetRegSizePerDpp() + { + return MPerDpp * NPerDpp / dpp_instr.wave_size; + } + + template + __device__ void + Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "base BaseType must be double, float, half, bfloat16, and int8_t!"); + + static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) { + dpp_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); + }); + } + + __device__ static auto GetLaneIdInWave() + { + return get_thread_local_1d_id() % dpp_instr.wave_size; + } + + __device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; } + + __device__ static auto GetLaneIdInLaneGroup() + { + return get_thread_local_1d_id() % dpp_instr.lanegroup_size; + } + + __device__ static auto GetLaneGroupIdInWave() + { + return GetLaneIdInWave() / dpp_instr.lanegroup_size; + } + + __device__ static auto GetDppOpIdx() + { + const auto lanegroupId = GetLaneGroupIdInWave(); + + constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup, + dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex( + make_multi_index(lanegroupId)); + + const auto m_dpp_idx = dpp_idx[I0]; + const auto n_dpp_idx = dpp_idx[I1]; + + return make_tuple(m_dpp_idx, n_dpp_idx); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M() + { + const auto laneId = get_thread_local_1d_id(); + const auto wave_row = laneId / dpp_instr.n_per_wave; + auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup(); + return make_tuple(0, m_idx % dpp_instr.m_per_wave); + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N() + { + const auto laneId = get_thread_local_1d_id(); + return make_tuple(0, laneId % dpp_instr.n_per_wave); + } + + __device__ static CIndex GetBeginOfThreadBlk() + { + const auto dpp_op_idx = GetDppOpIdx(); + + const auto m_dpp_op_idx = dpp_op_idx[I0]; + const auto n_dpp_op_idx = dpp_op_idx[I1]; + + index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup(); + index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup; + + return CIndex{m_offset, n_offset}; + } + + static constexpr auto dpp = DppSelector{}; + + static constexpr auto dpp_instr = dpp.selected_dpp; + + static constexpr auto K0PerDpp = 1; + static constexpr auto K1PerDpp = dpp.GetK1PerDpp(); + + __host__ __device__ static constexpr auto GetCMNThreadBlkLengths() + { + return make_tuple(Number{}, Number{}); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 979f3567e987ed66fd45af5deeae7bb2ac3fc761..814b4167b8a6e69081f03d2e6f370a6e087f9ad5 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -328,7 +328,7 @@ struct WmmaSelector } #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> - static constexpr auto GetWmma() + static constexpr auto GetWmma() { return WmmaInstr::wmma_i32_16x16x16_iu4; } diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 814969ef42baa2cc8e1f4188f61b3d50d74ea265..835075b7f2c95c3a3630cdab874cbc90157e2dac 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -31,7 +31,13 @@ enum struct MfmaInstr mfma_i32_16x16x32i8, mfma_f64_16x16x4f64, mfma_f32_32x32x16f8f8, - mfma_f32_16x16x32f8f8 + mfma_f32_16x16x32f8f8, + mfma_f32_32x32x16bf8bf8, + mfma_f32_16x16x32bf8bf8, + mfma_f32_32x32x16f8bf8, + mfma_f32_16x16x32f8bf8, + mfma_f32_32x32x16bf8f8, + mfma_f32_16x16x32bf8f8 }; template @@ -500,10 +506,148 @@ struct mfma_type } }; -template +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16f8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32f8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf8f8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf8f8::Run(a, b, reg_c); + } +}; + +template struct MfmaSelector { - template + template static constexpr auto GetMfma(); template <> @@ -652,7 +796,44 @@ struct MfmaSelector return MfmaInstr::mfma_f32_16x16x32f8f8; } - static constexpr auto selected_mfma = mfma_type()>{}; + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16bf8bf8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32bf8bf8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16f8bf8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32f8bf8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16bf8f8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32bf8f8; + } + + static constexpr auto selected_mfma = + mfma_type()>{}; __host__ __device__ constexpr MfmaSelector() { @@ -699,7 +880,8 @@ template + typename additional_type = base_type, + bool TransposeC = false> struct XdlopsGemm { static constexpr auto I0 = Number<0>{}; @@ -850,10 +1032,14 @@ struct XdlopsGemm template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "base base_type must be double, float, half, bfloat16, and int8_t!"); + static_assert( + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || + (is_same::value && is_same::value) || + (is_same::value && is_same::value), + "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { if constexpr(!TransposeC) @@ -949,7 +1135,7 @@ struct XdlopsGemm return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; } - static constexpr auto mfma = MfmaSelector{}; + static constexpr auto mfma = MfmaSelector{}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp index b1b203b43f381ef5779ac173c609ee893a70966b..2be0b66812434eb5127f5a97534ca4dc36b68273 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -164,6 +164,7 @@ template < index_t BK1, index_t GemmMPerBlock, index_t GemmNPerBlock, + index_t GemmKPerBlock, bool DoPadGemmM, bool DoPadGemmN> struct TransformConvBwdDataToGemm_v1 @@ -308,9 +309,6 @@ struct TransformConvBwdDataToGemm_v1 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); - const index_t AK0 = - math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, AK1); - if constexpr(NDimSpatial == 2) { // A: output tensor @@ -367,9 +365,11 @@ struct TransformConvBwdDataToGemm_v1 const auto out_gemmk_gemmm_padded_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( out_gemmk_gemmmraw_grid_desc, - make_tuple(AK1, GemmMPerBlock), + make_tuple(GemmKPerBlock, GemmMPerBlock), Sequence{}); + const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1; + const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( out_gemmk_gemmm_padded_grid_desc, make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), @@ -460,9 +460,11 @@ struct TransformConvBwdDataToGemm_v1 const auto out_gemmk_gemmm_padded_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( out_gemmk_gemmmraw_grid_desc, - make_tuple(AK1, GemmMPerBlock), + make_tuple(GemmKPerBlock, GemmMPerBlock), Sequence{}); + const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1; + const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( out_gemmk_gemmm_padded_grid_desc, make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), @@ -568,9 +570,6 @@ struct TransformConvBwdDataToGemm_v1 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); - const index_t BK0 = - math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, BK1); - // B weight tensor if constexpr(NDimSpatial == 2) { @@ -617,9 +616,11 @@ struct TransformConvBwdDataToGemm_v1 const auto wei_gemmk_gemmn_padded_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( wei_gemmk_gemmnraw_grid_desc, - make_tuple(BK1, GemmNPerBlock), + make_tuple(GemmKPerBlock, GemmNPerBlock), Sequence{}); + const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1; + const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor( wei_gemmk_gemmn_padded_grid_desc, make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), @@ -690,17 +691,19 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - const auto wei_gemmk_gemm_padded_grid_desc = + const auto wei_gemmk_gemmn_padded_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( wei_gemmk_gemmnraw_grid_desc, - make_tuple(BK1, GemmNPerBlock), + make_tuple(GemmKPerBlock, GemmNPerBlock), Sequence{}); + const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1; + const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor( - wei_gemmk_gemm_padded_grid_desc, - make_tuple( - make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(wei_gemmk_gemm_padded_grid_desc.GetLength(I1))), + wei_gemmk_gemmn_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform( + wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index cee3d2825b13facc912cdc8de21774ce793e9a2c..e2f75142d4187c9581a8816f56ccb2f3608d64ab 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -20,348 +20,13 @@ struct TransformConvFwdToGemm static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - template , - bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& /* a_g_n_c_wis_strides */, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Wi = a_g_n_c_wis_lengths[3]; - - const index_t Wo = c_g_n_k_wos_lengths[3]; - - const index_t ConvStrideW = conv_filter_strides[0]; - - if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor_packed(make_tuple(NWo, C)); - - return in_gemmm_gemmk_desc; - } - else if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Pad0) - { - const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); - - const auto in_n_wo_c_desc = transform_tensor_descriptor( - in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - else - { - const index_t X = b_g_k_c_xs_lengths[3]; - const index_t ConvDilationW = conv_filter_dilations[0]; - const index_t InLeftPadW = input_left_pads[0]; - const index_t InRightPadW = input_right_pads[0]; - - const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); - - const auto in_n_wip_c_desc = transform_tensor_descriptor( - in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_n_x_wo_c_desc = transform_tensor_descriptor( - in_n_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), - make_merge_transform(make_tuple(X, C))), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - } - - template , - bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& /* a_g_n_c_wis_strides */, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Hi = a_g_n_c_wis_lengths[3]; - const index_t Wi = a_g_n_c_wis_lengths[4]; - - const index_t Ho = c_g_n_k_wos_lengths[3]; - const index_t Wo = c_g_n_k_wos_lengths[4]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; - - if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C)); - - return in_gemmm_gemmk_desc; - } - else if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Pad0) - { - const auto in_n_hi_wi_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - - const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( - in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - else - { - const index_t Y = b_g_k_c_xs_lengths[3]; - const index_t X = b_g_k_c_xs_lengths[4]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - const auto in_n_hi_wi_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - - const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( - in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( - in_n_hip_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_merge_transform(make_tuple(Y, X, C))), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - } - - template , - bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& /* a_g_n_c_wis_strides */, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Di = a_g_n_c_wis_lengths[3]; - const index_t Hi = a_g_n_c_wis_lengths[4]; - const index_t Wi = a_g_n_c_wis_lengths[5]; - - const index_t Do = c_g_n_k_wos_lengths[3]; - const index_t Ho = c_g_n_k_wos_lengths[4]; - const index_t Wo = c_g_n_k_wos_lengths[5]; - - const index_t ConvStrideD = conv_filter_strides[0]; - const index_t ConvStrideH = conv_filter_strides[1]; - const index_t ConvStrideW = conv_filter_strides[2]; - - if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NDoHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C)); - - return in_gemmm_gemmk_desc; - } - else if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Pad0) - { - const auto in_n_di_hi_wi_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); - - const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_do_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - else - { - const index_t Z = b_g_k_c_xs_lengths[3]; - const index_t Y = b_g_k_c_xs_lengths[4]; - const index_t X = b_g_k_c_xs_lengths[5]; - - const index_t ConvDilationD = conv_filter_dilations[0]; - const index_t ConvDilationH = conv_filter_dilations[1]; - const index_t ConvDilationW = conv_filter_dilations[2]; - - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; - - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; - - const auto in_n_di_hi_wi_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); - - const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( - in_n_hip_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); - - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_z_do_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_merge_transform(make_tuple(Z, Y, X, C))), - make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - } - // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as // properties template || - is_same_v), + is_same_v || + is_same_v), bool>::type = false> static auto MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, @@ -473,7 +138,8 @@ struct TransformConvFwdToGemm template || - is_same_v), + is_same_v || + is_same_v), bool>::type = false> static auto MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, @@ -601,7 +267,8 @@ struct TransformConvFwdToGemm template || - is_same_v), + is_same_v || + is_same_v), bool>::type = false> static auto MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, @@ -855,22 +522,21 @@ struct TransformConvFwdToGemm // for output bias template || - is_same_v, + typename std::enable_if, bool>::type = false> - static auto - MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */) + static auto MakeCDescriptor_M_N(const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides) { - const index_t N = c_g_n_k_wos_lengths[1]; - const index_t K = c_g_n_k_wos_lengths[2]; + const index_t N = c_g_n_k_wos_lengths[1]; + const index_t K = c_g_n_k_wos_lengths[2]; + const index_t KStride = c_g_n_k_wos_strides[2]; const index_t NHoWo = N * ck::accumulate_n( c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); const auto out_gemmm_gemmn_desc = - make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, I1)); + make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride)); return out_gemmm_gemmn_desc; } diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 897cb4f249f415a1ca3a6660063ebc946641bd47..2ea5419d091cdb54c4e5f3d9153b4a932606d15f 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -299,584 +299,255 @@ enum struct AmdBufferCoherenceEnum GLC_SLC = 3, }; -template -__device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset) +template +__device__ typename vector_type::type +amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) { - static_assert( - (is_same::value && (N == 1 || N == 2 || N == 4)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), - "wrong! not implemented"); + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + "wrong! not implemented"); - if constexpr(is_same::value) + if constexpr(N == 1) { - // use fp32 load to mimic fp64 load - if constexpr(N == 1) - { - const float2_t tmp = - llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - else if constexpr(N == 2) - { - const float4_t tmp = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - else if constexpr(N == 4) - { - const float4_t f32_0 = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - const float4_t f32_1 = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); - vector_type tmp; - - tmp.AsType()(Number<0>{}) = bit_cast(f32_0); - tmp.AsType()(Number<1>{}) = bit_cast(f32_1); - - return tmp.AsType()(Number<0>{}); - } + return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 2) { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - return llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - return llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + + int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, static_cast(coherence)); - } - else if constexpr(N == 8) - { - vector_type tmp; - - tmp.AsType()(Number<0>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - tmp.AsType()(Number<1>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); - - return tmp.AsType()(Number<0>{}); - } + return bit_cast(tmp); } - else if constexpr(is_same::value) + else if constexpr(N == 4) { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - return llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource, + int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, static_cast(coherence)); - } - else if constexpr(N == 4) - { - return llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - // use fp32 load to mimic fp16 load - float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - return bit_cast(tmp); - } + return bit_cast(tmp); } - else if constexpr(is_same::value) + else if constexpr(N == 8) { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - return llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - return llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - } - else if constexpr(is_same::value) - { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - return llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - return llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - vector_type tmp; - - tmp.AsType()(Number<0>{}) = - llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - tmp.AsType()(Number<1>{}) = - llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int32_t), - static_cast(coherence)); - return tmp.AsType()(Number<0>{}); - } - } - else if constexpr(is_same::value) - { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - return llvm_amdgcn_raw_buffer_load_i8x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); -#else - int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, + int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, static_cast(coherence)); - return bit_cast(tmp); -#endif - } - else if constexpr(N == 4) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - return llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); -#else - int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, + return bit_cast(tmp); + } + else if constexpr(N == 16) + { + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, static_cast(coherence)); + return bit_cast(tmp); + } + else if constexpr(N == 32) + { + int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + int32x4_t tmp1 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + static_cast(coherence)); + vector_type tmp; - return bit_cast(tmp); -#endif - } - else if constexpr(N == 8) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - vector_type tmp; - - tmp.AsType()(Number<0>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - tmp.AsType()(Number<1>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int8_t), - static_cast(coherence)); - - return tmp.AsType()(Number<0>{}); -#else - int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + tmp.AsType()(Number<0>{}) = tmp0; + tmp.AsType()(Number<1>{}) = tmp1; - return bit_cast(tmp); -#endif - } - else if constexpr(N == 16) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - vector_type tmp; - - tmp.AsType()(Number<0>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - tmp.AsType()(Number<1>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int8_t), - static_cast(coherence)); - - tmp.AsType()(Number<2>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 8 * sizeof(int8_t), - static_cast(coherence)); - - tmp.AsType()(Number<3>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 12 * sizeof(int8_t), - static_cast(coherence)); - - return tmp.AsType()(Number<0>{}); -#else - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); + return bit_cast(tmp); + } + else if constexpr(N == 64) + { + int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + int32x4_t tmp1 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + static_cast(coherence)); + int32x4_t tmp2 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(int32_t), + static_cast(coherence)); + int32x4_t tmp3 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(int32_t), + static_cast(coherence)); - return bit_cast(tmp); -#endif - } + vector_type tmp; + + tmp.AsType()(Number<0>{}) = tmp0; + tmp.AsType()(Number<1>{}) = tmp1; + tmp.AsType()(Number<2>{}) = tmp2; + tmp.AsType()(Number<3>{}) = tmp3; + + return bit_cast(tmp); } } template -__device__ void amd_buffer_store_impl(const typename vector_type::type src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) +__device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) { static_assert( - (is_same::value && (N == 1 || N == 2)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); - if constexpr(is_same::value) + using r_t = typename vector_type::type; + auto raw_data = amd_buffer_load_impl_raw( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset); + return bit_cast(raw_data); +} + +template +__device__ void +amd_buffer_store_impl_raw(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + "wrong! not implemented"); + + if constexpr(N == 1) { - // use fp32 store to mimic fp64 store - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } + llvm_amdgcn_raw_buffer_store_i8(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 2) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp32(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - vector_type tmp{src_thread_data}; - llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType()[Number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType()[Number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); - } + + llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 4) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { -#if 0 - vector_type tmp{src_thread_data}; - - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(half_t), - static_cast(coherence)); -#else - llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); -#endif - } + llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 8) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_i16(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - vector_type tmp{src_thread_data}; - - llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[Number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[Number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(bhalf_t), - static_cast(coherence)); - } + llvm_amdgcn_raw_buffer_store_i32x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 16) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_i32(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } + llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 32) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_i8(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); -#else - llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); -#endif - } - else if constexpr(N == 4) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); -#else - llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); -#endif - } - else if constexpr(N == 8) - { - llvm_amdgcn_raw_buffer_store_i32x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 16) - { - llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } + vector_type tmp{bit_cast(src_thread_data)}; + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); } + else if constexpr(N == 64) + { + vector_type tmp{bit_cast(src_thread_data)}; + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 8, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 12, + static_cast(coherence)); + } +} + +template +__device__ void amd_buffer_store_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! not implemented"); + + using r_t = typename vector_type::type; + + amd_buffer_store_impl_raw(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset); } template @@ -1127,31 +798,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; + return amd_buffer_load_impl( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); - if constexpr(is_same::value) - { - auto tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); - return bit_cast(tmp); - } - else - { - return amd_buffer_load_impl( - src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); - } #else - if constexpr(is_same::value) - { - auto tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); - return src_thread_element_valid ? bit_cast(tmp) : vector_t(0); - } - else - { - vector_t tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); - return src_thread_element_valid ? tmp : vector_t(0); - } + + vector_t tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); + return src_thread_element_valid ? tmp : vector_t(0); #endif } @@ -1209,34 +863,13 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - - if constexpr(is_same::value) - { - auto tmp = - bit_cast::type::type>(src_thread_data); - amd_buffer_store_impl( - tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); - } - else - { - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); - } + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); #else if(dst_thread_element_valid) { - if constexpr(is_same::value) - { - auto tmp = bit_cast::type::type>( - src_thread_data); - amd_buffer_store_impl( - tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - } - else - { - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - } + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); } #endif } @@ -1311,4 +944,51 @@ amd_buffer_atomic_max(const typename vector_type_maker::type::type src_thr #endif } +// Direct loads from global to LDS. +__device__ void +llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, + __attribute__((address_space(3))) uint32_t* lds_ptr, + index_t size, + index_t voffset, + index_t soffset, + index_t offset, + index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); + +template +__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, + const index_t global_offset, + T* lds_base_ptr, + const index_t lds_offset, + const bool is_valid, + const index_t src_element_space_size) +{ + // Direct loads require that each thread reads and writes exactly a single DWORD. + constexpr auto dword_bytes = 4; + constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; + static_assert(bytes_per_thread == dword_bytes); + + const uint32_t* global_ptr = + reinterpret_cast(reinterpret_cast(global_base_ptr)); + const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size); + const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; + +#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM + T* lds_ptr = lds_base_ptr + lds_offset; + auto const lds_ptr_sgpr = + __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), + "v"(global_offset_bytes), + "s"(src_resource)); +#else + // LDS pointer must be attributed with the LDS address space. + __attribute__((address_space(3))) uint32_t* lds_ptr = + reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( + reinterpret_cast(lds_base_ptr + lds_offset)); + + llvm_amdgcn_raw_buffer_load_lds( + src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); +#endif +} + } // namespace ck diff --git a/include/ck/utility/amd_gemm_dpp.hpp b/include/ck/utility/amd_gemm_dpp.hpp index 8d6c7eede9d9fe38efd40ab12d6c401c999b6d5b..a28292dade3dcad8ddcd5bd3c0cd119aeb4b4253 100644 --- a/include/ck/utility/amd_gemm_dpp.hpp +++ b/include/ck/utility/amd_gemm_dpp.hpp @@ -5,17 +5,63 @@ #include "ck/utility/common_header.hpp" #include "ck/utility/math.hpp" -#include "ck/utility/amd_gemm_dpp.hpp" +#include "ck/utility/inner_product_dpp8.hpp" namespace ck { namespace dpp8 { -/// Number of lanes that can share data using DPP8 modifiers. -constexpr index_t lane_group_size = 8; +template +struct dpp_datatypes; -__device__ index_t get_lane_group_local_idx() { return threadIdx.x / lane_group_size; } -__device__ index_t get_thread_idx_in_lane_group() { return threadIdx.x % lane_group_size; } +template <> +struct dpp_datatypes +{ + // Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a + // single instruction. + using a_dtype = half_t; + using b_dtype = half_t; + using c_dtype = float; + static constexpr index_t k_per_instr = 2; +}; + +template +struct DppLanegroupGemm +{ + using datatypes_conf = dpp_datatypes; + using ADataType = typename datatypes_conf::a_dtype; + using BDataType = typename datatypes_conf::b_dtype; + using CDataType = typename datatypes_conf::c_dtype; + + __device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec) + { + constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread; + + const vector_type a_vector{a_vec}; + const vector_type b_vector{b_vec}; + + static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) { + float c = c_vec.template AsType()(c_idx); + // Next `c_idx` implies that we need to pull data from the next lane. + constexpr index_t source_lane = c_idx; + static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) { + const auto a_k_vec = a_vector.template AsType()[k_chunk]; + const auto b_k_vec = b_vector.template AsType()[k_chunk]; + ck::dpp8:: + inner_product_dpp( + a_k_vec, b_k_vec, c); + }); + c_vec.template AsType()(c_idx) = c; + }); + } +}; } // namespace dpp8 diff --git a/include/ck/utility/amd_lds.hpp b/include/ck/utility/amd_lds.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c218fded96788214688ed28fd22993dfa2f5dc66 --- /dev/null +++ b/include/ck/utility/amd_lds.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_address_space.hpp" +#include "ck/utility/dynamic_buffer.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +namespace lds_utils { + +/** \brief Allocate a given number of buffers in LDS and return them as a tuple. + * + * \tparam DataType Data type of elements to be stored in LDS. + * \tparam NumBuffers Number of buffers to be allocated. + * \param lds_ptr Address of the beginning of LDS space. + * \param num_elems_per_buffer Number of elements to allocate per single buffer. + * \param start_offset_elems Number of elements to move from the start of LDS for the allocation of + * the first buffer. \param lds_alignment Alignment of every buffer allocation given as a number of + * elements. \return Tuple of dynamic buffers representing memory allocated in LDS. + */ +template +__device__ static auto AllocateLdsBuffers(void* lds_ptr, + int32_t num_elems_per_buffer, + int32_t start_offset_elems, + int32_t lds_alignment) +{ + const DataType* lds_start = static_cast(lds_ptr) + start_offset_elems; + const int32_t single_buffer_offset = + math::integer_least_multiple(num_elems_per_buffer, lds_alignment); + return generate_tuple( + [&](auto i) { + const int32_t local_offset = i * single_buffer_offset; + return make_dynamic_buffer(lds_start + local_offset, + num_elems_per_buffer); + }, + Number{}); +} + +} // namespace lds_utils +} // namespace ck diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index dd7f0b770a11470037b1d33164d830cb9819e8d1..1bb0140f3e2008c23cf589ff35e6090309e03802 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -9,6 +9,9 @@ // TODO: Add arch limitation namespace ck { +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) +#define __gfx11__ +#endif /********************************WAVE32 MODE***********************************************/ // src: fp16, dst: fp32 @@ -25,7 +28,7 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> // delete them. // amd_assembly_wmma_f32_16x16x16_f16_w32( // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); #else @@ -46,7 +49,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16> template __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); @@ -71,7 +74,7 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); #else @@ -95,7 +98,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); @@ -117,7 +120,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> template __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( neg_a, @@ -145,7 +148,7 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16> template __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); #else @@ -166,7 +169,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16> template __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); @@ -191,7 +194,7 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); #else @@ -215,7 +218,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); @@ -237,7 +240,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> template __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#if defined(__gfx11__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64( neg_a, diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index ea7755036fc64b0af6aa340e062e1277c485e785..0ee52b9570f27b1a00916870e68f878009bcfc10 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1,12 +1,13 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_AMD_XDLOPS_HPP -#define CK_AMD_XDLOPS_HPP - -#include "data_type.hpp" +#pragma once namespace ck { +// Define the common macro for MI300 models +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#define __gfx94__ +#endif // fp32 template @@ -344,7 +345,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> template __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c) { -#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx90a__) || defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else @@ -364,7 +365,7 @@ struct intrin_mfma_f32_32x32x16f8f8<32, 32> template __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(reg_a), @@ -396,7 +397,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> template __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) { -#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( bit_cast(reg_a), bit_cast(reg_b), @@ -417,5 +418,194 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> #endif } }; -} // namespace ck + +template +struct intrin_mfma_f32_32x32x16bf8bf8; + +template <> +struct intrin_mfma_f32_32x32x16bf8bf8<32, 32> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32bf8bf8; + +template <> +struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_32x32x16f8bf8; + +template <> +struct intrin_mfma_f32_32x32x16f8bf8<32, 32> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32f8bf8; + +template <> +struct intrin_mfma_f32_16x16x32f8bf8<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); #endif + } +}; + +template +struct intrin_mfma_f32_32x32x16bf8f8; + +template <> +struct intrin_mfma_f32_32x32x16bf8f8<32, 32> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32bf8f8; + +template <> +struct intrin_mfma_f32_16x16x32bf8f8<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +} // namespace ck diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 8f896a4fca7075c4cce6f04e4685bcbe4305de58..fd882a699ec629fc9aee7ca4e6bce8eb1c3ecbdf 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -24,10 +24,9 @@ using std::byte; using bhalf_t = ushort; using half_t = _Float16; -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -using int4_t = _BitInt(4); -#endif -using f8_t = uint8_t; +using int4_t = _BitInt(4); +using f8_t = _BitInt(8); +using bf8_t = unsigned _BitInt(8); // vector_type template @@ -165,7 +164,13 @@ struct scalar_type static constexpr index_t vector_size = 1; }; -// +template <> +struct scalar_type +{ + using type = bf8_t; + static constexpr index_t vector_size = 1; +}; + template struct vector_type { @@ -199,6 +204,7 @@ struct vector_type } }; +int static err = 0; template struct vector_type { @@ -231,6 +237,10 @@ struct vector_type { return data_.d2x1_; } + else + { + return err; + } } template @@ -246,6 +256,10 @@ struct vector_type { return data_.d2x1_; } + else + { + return err; + } } }; @@ -288,6 +302,10 @@ struct vector_type { return data_.d4x1_; } + else + { + return err; + } } template @@ -308,6 +326,10 @@ struct vector_type { return data_.d4x1_; } + else + { + return err; + } } }; @@ -357,6 +379,10 @@ struct vector_type { return data_.d8x1_; } + else + { + return err; + } } template @@ -382,6 +408,10 @@ struct vector_type { return data_.d8x1_; } + else + { + return err; + } } }; @@ -438,6 +468,10 @@ struct vector_type { return data_.d16x1_; } + else + { + return err; + } } template @@ -468,6 +502,10 @@ struct vector_type { return data_.d16x1_; } + else + { + return err; + } } }; @@ -530,6 +568,10 @@ struct vector_type { return data_.d32x1_; } + else + { + return err; + } } template @@ -564,6 +606,10 @@ struct vector_type { return data_.d32x1_; } + else + { + return err; + } } }; @@ -633,6 +679,10 @@ struct vector_type { return data_.d64x1_; } + else + { + return err; + } } template @@ -672,6 +722,10 @@ struct vector_type { return data_.d64x1_; } + else + { + return err; + } } }; @@ -747,6 +801,10 @@ struct vector_type { return data_.d128x1_; } + else + { + return err; + } } template @@ -790,6 +848,10 @@ struct vector_type { return data_.d128x1_; } + else + { + return err; + } } }; @@ -871,6 +933,10 @@ struct vector_type { return data_.d256x1_; } + else + { + return err; + } } template @@ -918,6 +984,10 @@ struct vector_type { return data_.d256x1_; } + else + { + return err; + } } }; @@ -975,6 +1045,14 @@ using f8x16_t = typename vector_type::type; using f8x32_t = typename vector_type::type; using f8x64_t = typename vector_type::type; +// bf8 +using bf8x2_t = typename vector_type::type; +using bf8x4_t = typename vector_type::type; +using bf8x8_t = typename vector_type::type; +using bf8x16_t = typename vector_type::type; +using bf8x32_t = typename vector_type::type; +using bf8x64_t = typename vector_type::type; + template struct NumericLimits; @@ -1100,18 +1178,103 @@ struct NumericLimits template <> struct NumericLimits { + // negative zero nan mode with exp bias = 8 static constexpr uint8_t binary_min = 0x08; // 0b00001000 - static constexpr uint8_t binary_max = 0x77; // 0b01110111 - static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 7 + // static constexpr uint8_t binary_min = 0x08; // 0b00001000 + // static constexpr uint8_t binary_max = 0x77; // 0b01110111 + // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 + + __host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); } + + __host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); } + + __host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); } + + __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); } +}; + +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 16 + static constexpr uint8_t binary_min = 0x04; // 0b00000100 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 15 + // static constexpr uint8_t binary_min = 0x04; // 0b00000100 + // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 + // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= + + __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); } + + __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); } - __host__ __device__ static constexpr f8_t Min() { return bit_cast(binary_min); } + __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); } - __host__ __device__ static constexpr f8_t Max() { return bit_cast(binary_max); } + __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); } +}; - __host__ __device__ static constexpr f8_t Lowest() { return bit_cast(binary_lowest); } +template +struct NumericUtils +{ +}; - __host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast(binary_qnan); } +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 23; + static constexpr int bias = 127; + static constexpr uint32_t nan_mask = 0x7F800000; + static constexpr uint32_t head_mask = 0xFF800000; + static constexpr uint32_t mant_mask = 0x7FFFFF; + static constexpr uint32_t exp_mask = 0xFF; + static constexpr uint32_t Inf = 0x7F800000; + static constexpr uint32_t NegInf = 0xFF800000; + static constexpr uint32_t NaN = 0x7F800001; + static constexpr uint32_t Neg0 = 0x80000000; + using bitwise_type = uint32_t; }; +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 10; + static constexpr int bias = 15; + static constexpr uint16_t nan_mask = 0x7C00; + static constexpr uint16_t head_mask = 0xFC00; + static constexpr uint16_t mant_mask = 0x3FF; + static constexpr uint16_t exp_mask = 0x1F; + static constexpr uint32_t Inf = 0x7C00; + static constexpr uint32_t NegInf = 0xFC00; + static constexpr uint32_t NaN = 0x7C01; + static constexpr uint32_t Neg0 = 0x8000; + using bitwise_type = uint16_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 4; + static constexpr int mant = 3; + static constexpr int bias = 8; // negative zero nan mode + // static constexpr int bias = 7; // ieee mode +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 2; + static constexpr int bias = 16; // negative zero nan mode + // static constexpr int bias = 15; // ieee mode +}; } // namespace ck diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 02d61f34ed5192480d09f4432eec30e65e7264b9..76390e614e814cf40a6e9498552e116d181ab1b9 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -140,13 +140,59 @@ struct DynamicBuffer } else if constexpr(Op == InMemoryDataOperationEnum::Add) { - auto tmp = this->template Get(i, is_valid_element); - this->template Set(i, is_valid_element, x + tmp); - // tmp += x; - // this->template Set(i, is_valid_element, tmp); + auto tmp = this->template Get(i, is_valid_element); + using scalar_t = typename scalar_type>::type; + // handle bfloat addition + if constexpr(is_same_v) + { + if constexpr(is_scalar_type::value) + { + // Scalar type + auto result = + type_convert(type_convert(x) + type_convert(tmp)); + this->template Set(i, is_valid_element, result); + } + else + { + // Vector type + constexpr auto vector_size = scalar_type>::vector_size; + const vector_type a_vector{tmp}; + const vector_type b_vector{x}; + static_for<0, vector_size, 1>{}([&](auto idx) { + auto result = type_convert( + type_convert(a_vector.template AsType()[idx]) + + type_convert(b_vector.template AsType()[idx])); + this->template Set(i + idx, is_valid_element, result); + }); + } + } + else + { + this->template Set(i, is_valid_element, x + tmp); + } } } + template + __host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf, + index_t src_offset, + index_t dst_offset, + bool is_valid_element) const + { + // Copy data from global to LDS memory using direct loads. + static_assert(GetAddressSpace() == AddressSpaceEnum::Global, + "Source data must come from a global memory buffer."); + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "Destination data must be stored in an LDS memory buffer."); + + amd_direct_load_global_to_lds(p_data_, + src_offset, + dst_buf.p_data_, + dst_offset, + is_valid_element, + element_space_size_); + } + template >::type, typename scalar_type>::type>::value, diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index ef4d16a8d09de7ee898ba6815efee98974c97a20..1960667732646d04242e168485561b40e5417ab4 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -16,59 +16,46 @@ enum class f8_rounding_mode stochastic }; +__host__ inline int clz(uint32_t x) { return __builtin_clz(x); } +__device__ inline int clz(uint32_t x) { return __clz(x); } + } // namespace ck namespace ck::utils { namespace { -template -__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) +template +__host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) { - // check data type - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; - - // fp8 exponent/mantissa layout - constexpr int f8_exp = 4; - constexpr int f8_mant = 3; + // fp8/bf8 exponent/mantissa layout + constexpr int out_exp = NumericUtils::exp; + constexpr int out_mant = NumericUtils::mant; - // resulting type exponent/mantissa layout - constexpr int type_exp = is_half ? 5 : 8; - constexpr int type_mant = is_half ? 10 : 23; + // original type exponent/mantissa layout + constexpr int in_exp = NumericUtils::exp; + constexpr int in_mant = NumericUtils::mant; - int exponent; + int exponent, bias; uint32_t head, mantissa, sign; // nan code is same for float and half - constexpr uint8_t nan_code = 0x80; - constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000; + constexpr Y nan_code = 0x80; + constexpr uint32_t nan_mask = NumericUtils::nan_mask; // convert to bitwise - typedef typename ck::conditional::value, uint16_t, uint32_t>::type - T_bitwise; + using T_bitwise = typename NumericUtils::bitwise_type; T_bitwise x_bitwise = *(reinterpret_cast(&x)); // unpack the input, depends on datatype - if constexpr(is_float) - { - head = x_bitwise & 0xFF800000; - mantissa = x_bitwise & 0x7FFFFF; - exponent = (head >> type_mant) & 0xFF; - sign = head >> (type_exp + type_mant); - } - else if constexpr(is_half) - { - head = x_bitwise & 0xFC00; - mantissa = x_bitwise & 0x3FF; - exponent = (head >> type_mant) & 0x1F; - sign = head >> (type_exp + type_mant); - } + head = x_bitwise & NumericUtils::head_mask; + mantissa = x_bitwise & NumericUtils::mant_mask; + exponent = (head >> in_mant) & NumericUtils::exp_mask; + sign = head >> (in_exp + in_mant); + bias = NumericUtils::bias; - uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant); - uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1; - constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2); - constexpr int exp_low_cutoff = - (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant); + uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1; + constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2); if constexpr(negative_zero_nan) { @@ -85,39 +72,103 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) if(x_bitwise == 0) return 0; - exponent -= exp_low_cutoff - 1; - if(exponent <= 0) - drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1; - mantissa += 1 << type_mant; - // apply random number if needed - mantissa += (stoch ? rng : mantissa) & drop_mask; - if(mantissa >= (2 << type_mant)) - { - mantissa >>= 1; - exponent++; + // First need to check if it is normal or denorm as there is a difference of implict 1 + // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift + // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for + // RNE, no need to add rng. Then probably need to check whether there is carry and adjust + // exponent and mantissa again3 + + // For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits + const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // out_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, out_exponent, exponent_diff; + + if(exponent == 0) + { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 +here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has +exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in +fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. +In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = out_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } + else + { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if(act_exponent <= out_denormal_act_exponent) + { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. + For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implict 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = out_denormal_act_exponent - act_exponent; + } + else + { // both fp32/fp16 and f8 are in normal range + exponent_diff = + 0; // exponent_diff=0 does not mean there is no difference for this case, + // act_exponent could be larger. Just that it does not need shift mantissa + } + mantissa += (1 << in_mant); // Add the implicit 1 into mantissa } - mantissa >>= (type_mant - f8_mant); - // check negative exponent - if(exponent <= 0) + bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) == + (1 << (in_mant - out_mant + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we + shift right as shift right could rip off some residual part and make something not midpoint look + like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than + midpoint, but after shift right by 4 bits, it would look like midpoint. */ + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1 << in_mant); + // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent + out_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + bool odd = + mantissa & + (1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1 + mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; + + // Now we deal with overflow + if(out_exponent == 0) { - if(x_bitwise == 0) - return 0; - else + if((1 << in_mant) & mantissa) + { + out_exponent = 1; // denormal overflow to become normal, promote exponent + // No need to make 1 implicit now as it will be addressed later + } + } + else + { + if((1 << (in_mant + 1)) & mantissa) { - // subnormal range; represented by a subnormal float8 (exponent 0) - // and involves loss of accuracy - mantissa >>= 1 - exponent; - exponent = 0; + mantissa >>= 1; + out_exponent++; + // No need to make 1 implicit now as it will be addressed later } } - // above range: quantize to maximum possible float of the same sign - else if(exponent > max_exp) + + mantissa >>= (in_mant - out_mant); + + if(out_exponent > max_exp) { if(clip) { - mantissa = (1 << f8_mant) - 1; - exponent = max_exp; + mantissa = (1 << out_mant) - 1; + out_exponent = max_exp; } else { @@ -126,125 +177,117 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) } // check if x is 0.0 or -0.0 - if(exponent == 0 && mantissa == 0) - return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant)); - mantissa &= (1 << f8_mant) - 1; - return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa; + if(out_exponent == 0 && mantissa == 0) + return negative_zero_nan ? 0 : (sign << (out_exp + out_mant)); + mantissa &= (1 << out_mant) - 1; + return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa; } -template -__host__ __device__ T run_cast_from_f8(f8_t x) +template +__host__ __device__ Y run_cast_from_f8(X x) { - // check data type - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; - - // fp8 exponent/mantissa layout - constexpr int f8_exp = 4; - constexpr int f8_mant = 3; + // fp8/bf8 exponent/mantissa layout + constexpr int in_exp = NumericUtils::exp; + constexpr int in_mant = NumericUtils::mant; // resulting type exponent/mantissa layout - constexpr int type_exp = is_half ? 5 : 8; - constexpr int type_mant = is_half ? 10 : 23; + constexpr int out_exp = NumericUtils::exp; + constexpr int out_mant = NumericUtils::mant; // prepare the codes - constexpr uint8_t nan_code = 0x80; - T fInf, fNegInf, fNaN, fNeg0; - if constexpr(is_half) - { - constexpr uint16_t ihInf = 0x7C00; - constexpr uint16_t ihNegInf = 0xFC00; - constexpr uint16_t ihNaN = 0x7C01; - constexpr uint16_t ihNeg0 = 0x8000; - fInf = *(reinterpret_cast(&ihInf)); - fNegInf = *(reinterpret_cast(&ihNegInf)); - fNaN = *(reinterpret_cast(&ihNaN)); - fNeg0 = *(reinterpret_cast(&ihNeg0)); - } - else if constexpr(is_float) - { - constexpr uint32_t ifInf = 0x7F800000; - constexpr uint32_t ifNegInf = 0xFF800000; - constexpr uint32_t ifNaN = 0x7F800001; - constexpr uint32_t ifNeg0 = 0x80000000; - fInf = *(reinterpret_cast(&ifInf)); - fNegInf = *(reinterpret_cast(&ifNegInf)); - fNaN = *(reinterpret_cast(&ifNaN)); - fNeg0 = *(reinterpret_cast(&ifNeg0)); - } + constexpr X nan_code = 0x80; + Y Inf, NegInf, NaN, Neg0; + using T_bitwise = typename NumericUtils::bitwise_type; + + constexpr T_bitwise Inf_bitwise = NumericUtils::Inf; + constexpr T_bitwise NegInf_bitwise = NumericUtils::NegInf; + constexpr T_bitwise NaN_bitwise = NumericUtils::NaN; + constexpr T_bitwise Neg0_bitwise = NumericUtils::Neg0; + + Inf = *(reinterpret_cast(&Inf_bitwise)); + NegInf = *(reinterpret_cast(&NegInf_bitwise)); + NaN = *(reinterpret_cast(&NaN_bitwise)); + Neg0 = *(reinterpret_cast(&Neg0_bitwise)); + + // check if x is 0.0 + if(x == 0) + return static_cast(0); // unpack the input - uint32_t sign = x >> (f8_exp + f8_mant); - uint32_t mantissa = x & ((1 << f8_mant) - 1); - int exponent = (x & 0x7F) >> f8_mant; + uint32_t sign = x >> (in_exp + in_mant); + uint32_t mantissa = x & ((1 << in_mant) - 1); + int exponent = (x & 0x7F) >> in_mant; constexpr int exp_low_cutoff = - (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); - typename ck::conditional::value, uint16_t, uint32_t>::type retval; + (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + T_bitwise retval; if constexpr(negative_zero_nan) { if(x == nan_code) - return fNaN; + return NaN; } else { if(x == nan_code) - return fNeg0; - if(exponent == ((1 << f8_exp) - 1)) - return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + return Neg0; + if(exponent == ((1 << in_exp) - 1)) + return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN; + } + + if((NumericUtils::mant == 10) && (NumericUtils::mant == 2) && !negative_zero_nan) + { + retval = x; + retval <<= 8; + return *(reinterpret_cast(&retval)); } // subnormal input if(exponent == 0) { // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above - int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant); + int sh = 1 + clz(mantissa) - (32 - in_mant); mantissa <<= sh; - mantissa &= ((1 << f8_mant) - 1); exponent += 1 - sh; + mantissa &= ((1 << in_mant) - 1); } exponent += exp_low_cutoff - 1; - mantissa <<= type_mant - f8_mant; + mantissa <<= out_mant - in_mant; // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) if(exponent <= 0) { - mantissa |= 1 << type_mant; + mantissa |= 1 << out_mant; mantissa >>= 1 - exponent; exponent = 0; } - retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa; - return *(reinterpret_cast(&retval)); + retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; + return *(reinterpret_cast(&retval)); } } // namespace -template -__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng) +template +__host__ __device__ Y cast_to_f8(X x, uint32_t rng) { - // check datatype - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; - static_assert(is_half || is_float, "Only half and float can be casted to f8."); + // check datatypes + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "Only half and float can be casted."); - return run_cast_to_f8(x, rng); + return run_cast_to_f8(x, rng); } -template -__host__ __device__ T cast_from_f8(f8_t x) +template +__host__ __device__ Y cast_from_f8(X x) { // check datatype - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; static_assert(is_half || is_float, "only half and float are supported."); - // check if x is 0.0 - if(x == 0) - return static_cast(0); - - return run_cast_from_f8(x); + return run_cast_from_f8(x); } } // namespace ck::utils diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index b58b2b33191e23db2bda3f9d15f647f6d4f85f58..65efaf388ae528feab9fcf06818e111244fa091a 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -72,6 +72,18 @@ inner_product(const float4_t& a, const float4_t& b, f c); } +template <> +__device__ void inner_product(const bhalf_t& a, const bhalf_t& b, float& c) +{ + inner_product(type_convert(a), type_convert(b), c); +} + +template <> +__device__ void inner_product(const half_t& a, const half_t& b, float& c) +{ + inner_product(type_convert(a), type_convert(b), c); +} + template <> __device__ void inner_product(const half2_t& a, const half2_t& b, float& c) { @@ -180,6 +192,8 @@ inner_product(const int8x4_t& a, const int8x4_t& b, #else c = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b), c, false); #endif +#elif defined(CK_USE_AMD_V_DOT4_I32_I8_GFX11) + c = __builtin_amdgcn_sudot4(true, bit_cast(a), true, bit_cast(b), c, false); #else const vector_type a_vector{a}; const vector_type b_vector{b}; diff --git a/include/ck/utility/inner_product_dpp8.hpp b/include/ck/utility/inner_product_dpp8.hpp index ccd7a4e6284fa07d7f08a8559c782ff8767d0bd6..f079e2ca6490676d6c0e83e9e74dba205a8d6583 100644 --- a/include/ck/utility/inner_product_dpp8.hpp +++ b/include/ck/utility/inner_product_dpp8.hpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once + #include "amd_gemm_dpp.hpp" #include "data_type.hpp" #include "type_convert.hpp" @@ -10,6 +11,9 @@ namespace ck { namespace dpp8 { +/// Number of lanes that can share data using DPP8 modifiers. +constexpr index_t lane_group_size = 8; + template __device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c); diff --git a/include/ck/utility/is_detected.hpp b/include/ck/utility/is_detected.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7a324a6c458b3f1b8bb8037ccfac76e5eadceee0 --- /dev/null +++ b/include/ck/utility/is_detected.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +namespace detail { +template class Op, class... Args> +struct detector +{ + using value_t = std::false_type; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> +{ + using value_t = std::true_type; + using type = Op; +}; +} // namespace detail + +struct nonesuch +{ + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; + void operator=(nonesuch const&) = delete; +}; + +template