diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000000000000000000000000000000000..30f0dedd8d26e2282d92a6a8caee7a4529915664 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,6 @@ +# Documentation files +docs/* @saadrahim @LisaDelaney +*.md @saadrahim @LisaDelaney +*.rst @saadrahim @LisaDelaney +# Header directory +library/include/* @saadrahim @LisaDelaney diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 9cdf2d670c3584c10e1041dcbcb49da7c68cff5f..276690bd4f96f50eb6a5121d237c8c9bcb80224d 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -6,7 +6,7 @@ version: 2 updates: - package-ecosystem: "pip" # See documentation for possible values - directory: "/docs/.sphinx" # Location of package manifests + directory: "/docs/sphinx" # Location of package manifests open-pull-requests-limit: 10 schedule: interval: "daily" diff --git a/.gitignore b/.gitignore index 362fb9e2ef0e031c5a1b968a2125149a4d9fd757..7af066c82dcef3cf4cc846168567e90e0a0c9db3 100644 --- a/.gitignore +++ b/.gitignore @@ -49,10 +49,10 @@ build* install.dir* # documentation artifacts -build/ _build/ _images/ _static/ _templates/ _toc.yml docBin/ +_doxygen/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d6700ae05b51371a3d1ac50ba0d701e7384698a9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,14 @@ +repos: +- repo: local + hooks: + - id: clang-format + name: clang-format + entry: clang-format-12 -i --style=file + language: system + types_or: [c++, inc] + - id: copyright-year-checker + name: copyright-year-checker + entry: script/check_copyright_year.sh + verbose: false + language: script + types: [c++] diff --git a/.readthedocs.yaml b/.readthedocs.yaml index b73953683777f54db3ca26448243a236b9fb84eb..5f50df2525d2948bbabe36184f23310d47339d8e 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -11,8 +11,8 @@ build: sphinx: configuration: docs/conf.py -formats: [htmlzip] +formats: [htmlzip, pdf, epub] python: install: - - requirements: docs/.sphinx/requirements.txt + - requirements: docs/sphinx/requirements.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index 01883500465583097aa68642e89a89a06b6750b2..3e46a4ab4bd460095eaec86545ba212dac2b7e0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,25 +1,53 @@ -# Change Log for Composable Kernel +# Changelog for Composable Kernel Full documentation for Composable Kernel is not yet available. -## CK 0.2.0 for ROCm 5.5.0 +## (Unreleased) CK for ROCm 6.0.0 -### Fixed -- Fixed a bug in 6-dimensional kernels (#555). -- Fixed grouped ConvBwdWeight test case failure (#524). +### Fixes + - Fixed a hazard associated with inline v_dot (#808) + - Fixed two bugs in grouped convolution backward data without K padding (#848 #876) ### Optimizations -- Improve proformance of normalization kernel - -### Added -- Added support on NAVI3x. -- Added user tutorial (#563). -- Added more instances for irregular GEMM sizes (#560). -- Added inter-wave consumer-producer programming model for GEMM kernels (#310). -- Added multi-D GEMM client APIs (#534). -- Added multi-embeddings support (#542). -- Added Navi3x blockwise GEMM and real GEMM support (#541). -- Added Navi grouped ConvBwdWeight support (#505). - -### Changed -- Changed ... +None + +### Additions +- Added an image to a column kernel (#867) +- Added a column to an image kernel (#930) +- Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985) +- Grouped convolution support for small K and C (#822 #879 #897) +- Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) +- Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) +- Support for Batched Gemm DL (#732) + +### Changes + - Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) + +## CK 0.2.0 for ROCm 5.7.0 + +### Fixes +- Fixed a bug in 6-dimensional kernels (#555) +- Fixed a test case failure with grouped convolution backward weight (#524) + +### Optimizations +- Improved the performance of the normalization kernel + +### Additions +- New CMake flags: + - "DL_KERNELS"-- Must be set to "ON" in order to build the gemm_dl and batched_gemm_multi_d_dl instances + - "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types + - "INSTANCES_ONLY" -- Only builds CK library and instances without tests, examples, or profiler +- New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler +- Support for MI300A/MI300X +- Support for AMD RDNA 3 +- New user tutorial (#563) +- Additional instances for irregular GEMM sizes (#560) +- New inter-wave consumer-producer programming model for GEMM kernels (#310) +- GEMM with support multiple elementwise fusions (multi-D) (#534) +- Multi-embeddings support (#542) +- AMD RDNA 3 blockwise GEMM and real GEMM support (#541) +- AMD RDNA grouped convolution backward weight support (#505) +- MaxPool and AvgPool forward (#815); MaxPool backward (#750) + +### Changes +None diff --git a/CMakeLists.txt b/CMakeLists.txt index c9fb6b4552c81ef29e77dcde807f552b42916572..04674124cc6437cfc69e4cd7ea6dedb60ca93363 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,10 +1,85 @@ cmake_minimum_required(VERSION 3.14) +if(POLICY CMP0140) + # policies CMP0140 not known to CMake until 3.25 + cmake_policy(SET CMP0140 NEW) +endif() + +# This has to be initialized before the project() command appears +# Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE +if( NOT MSVC_IDE AND NOT CMAKE_BUILD_TYPE ) + set( CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel." ) +endif() + +# Default installation path +if(WIN32) + set(CMAKE_INSTALL_PREFIX "/opt/rocm/x86_64-w64-mingw32" CACHE PATH "") +else() + set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "") +endif() +set(version 1.1.0) # Check support for CUDA/HIP in Cmake -project(composable_kernel) +project(composable_kernel VERSION ${version}) list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") +if (DTYPES) + add_definitions(-DDTYPES) + if (DTYPES MATCHES "int8") + add_definitions(-DCK_ENABLE_INT8) + set(CK_ENABLE_INT8 "ON") + endif() + if (DTYPES MATCHES "fp8") + add_definitions(-DCK_ENABLE_FP8) + set(CK_ENABLE_FP8 "ON") + endif() + if (DTYPES MATCHES "bf8") + add_definitions(-DCK_ENABLE_BF8) + set(CK_ENABLE_BF8 "ON") + endif() + if (DTYPES MATCHES "fp16") + add_definitions(-DCK_ENABLE_FP16) + set(CK_ENABLE_FP16 "ON") + endif() + if (DTYPES MATCHES "fp32") + add_definitions(-DCK_ENABLE_FP32) + set(CK_ENABLE_FP32 "ON") + endif() + if (DTYPES MATCHES "fp64") + add_definitions(-DCK_ENABLE_FP64) + set(CK_ENABLE_FP64 "ON") + endif() + if (DTYPES MATCHES "bf16") + add_definitions(-DCK_ENABLE_BF16) + set(CK_ENABLE_BF16 "ON") + endif() + message("DTYPES macro set to ${DTYPES}") +else() + add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) + set(CK_ENABLE_ALL_DTYPES "ON") +endif() + +#for f8/bf8_t type +add_compile_options(-Wno-bit-int-extension) + +if(DL_KERNELS) + add_definitions(-DDL_KERNELS) + set(CK_ENABLE_DL_KERNELS "ON") +endif() + +if(INSTANCES_ONLY) + add_definitions(-DINSTANCES_ONLY) + set(CK_ENABLE_INSTANCES_ONLY "ON") +endif() + +# CK config file to record supported datatypes, etc. +configure_file("${PROJECT_SOURCE_DIR}/include/ck/config.h.in" "${PROJECT_BINARY_DIR}/include/ck/config.h") + +# CK version file to record release version as well as git commit hash +find_package(Git REQUIRED) +execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMIT_ID OUTPUT_STRIP_TRAILING_WHITESPACE) +configure_file("${PROJECT_SOURCE_DIR}/include/ck/version.h.in" "${PROJECT_BINARY_DIR}/include/ck/version.h") + enable_testing() set(ROCM_SYMLINK_LIBS OFF) @@ -16,11 +91,61 @@ include(ROCMSetupVersion) include(ROCMInstallSymlinks) include(ROCMCreatePackage) include(CheckCXXCompilerFlag) - -rocm_setup_version(VERSION 0.2.0) +include(ROCMCheckTargetIds) include(TargetFlags) + +rocm_setup_version(VERSION ${version}) + list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) +message("GPU_TARGETS= ${GPU_TARGETS}") + +message("checking which targets are supported") +#This is the list of targets to be used in case GPU_TARGETS is not set on command line +#These targets will be filtered and only supported ones will be used +#Setting GPU_TARGETS on command line will override this list +if(NOT PROFILER_ONLY) + rocm_check_target_ids(DEFAULT_GPU_TARGETS + TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") +else() + add_definitions(-DPROFILER_ONLY) + set(GPU_TARGETS "" CACHE STRING "" FORCE) + if(GPU_TARGETS) + message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11") + endif() + if(GPU_ARCH MATCHES "gfx90") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a") + elseif(GPU_ARCH MATCHES "gfx94") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx940;gfx941;gfx942") + elseif(GPU_ARCH MATCHES "gfx10") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") + elseif(GPU_ARCH MATCHES "gfx11") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") + else() + message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11") + endif() + set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) +endif() + +message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}") + +set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) + +if(GPU_TARGETS) + message("Building CK for the following targets: ${GPU_TARGETS}") +else() + message("Building CK for the following targets: ${AMDGPU_TARGETS}") +endif() +find_package(hip) +# No assumption that HIP kernels are launched with uniform block size for backward compatibility +# SWDEV-413293 and https://reviews.llvm.org/D155213 +math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}") +message("hip_version_flat=${hip_VERSION_FLAT}") +if(${hip_VERSION_FLAT} GREATER 500723302) + message("Adding the fno-offload-uniform-block compiler flag") + add_compile_options(-fno-offload-uniform-block) +endif() + option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF) option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF) @@ -238,18 +363,20 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) +# set CK project include directories include_directories(BEFORE + ${PROJECT_BINARY_DIR}/include ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/library/include ${HIP_INCLUDE_DIRS} ) - SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") if(BUILD_DEV) - add_compile_options(-Werror) - add_compile_options(-Weverything) + add_compile_options(-Werror -Weverything) endif() +#add flags to reduce the size of binaries +add_compile_options(-Oz -flto=thin) message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) @@ -258,36 +385,76 @@ file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp" file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*) set(CK_DEVICE_INSTANCES) FOREACH(subdir_path ${dir_list}) - IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}") - list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance) - ENDIF() +set(target_dir) +IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}") + set(cmake_instance) + file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance) + set(add_inst 0) + if(("${cmake_instance}" MATCHES "fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8") + set(add_inst 1) + endif() + if(("${cmake_instance}" MATCHES "bf8" OR "${cmake_instance}" MATCHES "_b8") AND DTYPES MATCHES "bf8") + set(add_inst 1) + endif() + if(("${cmake_instance}" MATCHES "fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16") + set(add_inst 1) + endif() + if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32") + set(add_inst 1) + endif() + if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64") + set(add_inst 1) + endif() + if(("${cmake_instance}" MATCHES "bf16" OR "${cmake_instance}" MATCHES "_b16") AND DTYPES MATCHES "bf16") + set(add_inst 1) + endif() + if(("${cmake_instance}" MATCHES "int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8") + set(add_inst 1) + endif() + if(NOT "${cmake_instance}" MATCHES "DTYPES") + set(add_inst 1) + endif() + if(add_inst EQUAL 1 OR NOT DEFINED DTYPES) + list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance) + endif() +ENDIF() ENDFOREACH() + add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES}) +add_subdirectory(library) -rocm_package_setup_component(tests +if(NOT DEFINED INSTANCES_ONLY) + if(NOT DEFINED PROFILER_ONLY) + rocm_package_setup_component(tests LIBRARY_NAME composablekernel PACKAGE_NAME tests # Prevent -static suffix on package name -) + ) -rocm_package_setup_component(examples + rocm_package_setup_component(examples LIBRARY_NAME composablekernel PACKAGE_NAME examples -) + ) + add_subdirectory(example) + add_subdirectory(test) -rocm_package_setup_component(profiler + rocm_package_setup_component(profiler LIBRARY_NAME composablekernel - PACKAGE_NAME ckProfiler -) - -add_subdirectory(library) -add_subdirectory(example) -add_subdirectory(test) -add_subdirectory(profiler) + PACKAGE_NAME ckprofiler + ) + add_subdirectory(profiler) + else() + #When building PROFILER_ONLY, label the package with GPU_ARCH + rocm_package_setup_component(profiler + LIBRARY_NAME composablekernel + PACKAGE_NAME ckprofiler_${GPU_ARCH} + ) + add_subdirectory(profiler) + endif() +endif() #Create an interface target for the include only files and call it "composablekernels" include(CMakePackageConfigHelpers) -set(version 1.0.0) write_basic_package_version_file( "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake" VERSION "${version}" @@ -295,9 +462,9 @@ write_basic_package_version_file( ) configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in - "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake" - INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel - NO_CHECK_REQUIRED_COMPONENTS_MACRO + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake" + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel + NO_CHECK_REQUIRED_COMPONENTS_MACRO ) rocm_install(FILES @@ -306,6 +473,13 @@ rocm_install(FILES DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel ) +# Install CK version and configuration files +rocm_install(FILES + ${PROJECT_BINARY_DIR}/include/ck/version.h + ${PROJECT_BINARY_DIR}/include/ck/config.h + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck/ +) + set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") set(CPACK_RPM_PACKAGE_LICENSE "MIT") diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 8ccfe99c3cc73b643f8b92cb654005e54c0774bd..cdce5a46309f59b27d8c658e785f70bf743527db 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -4,11 +4,13 @@ This is the list of developers and contributors to Composable Kernel library ## Developers -[Chao Liu](https://github.com/asroy), [Jing Zhang](https://github.com/zjing14), 2018-2022 +[Chao Liu](https://github.com/asroy), [Jing Zhang](https://github.com/zjing14), 2018-2023 -[Letao Qin](https://github.com/ltqin), [Qianfeng Zhang](https://github.com/qianfengz), [Liang Huang](https://github.com/carlushuang), [Shaojie Wang](https://github.com/shaojiewang), 2019-2022 +[Letao Qin](https://github.com/ltqin), [Qianfeng Zhang](https://github.com/qianfengz), [Liang Huang](https://github.com/carlushuang), [Shaojie Wang](https://github.com/shaojiewang), 2019-2023 -[Anthony Chang](https://github.com/rosenrodt), [Chunyu Lai](https://github.com/rocking5566), [Illia Silin](https://github.com/illsilin), [Adam Osewski](https://github.com/aosewski), [Poyen Chen](https://github.com/poyenc), [Rosty Geyyer](https://github.com/geyyer), 2022 +[Anthony Chang](https://github.com/rosenrodt), [Chunyu Lai](https://github.com/rocking5566), [Illia Silin](https://github.com/illsilin), [Adam Osewski](https://github.com/aosewski), [Poyen Chen](https://github.com/poyenc), [Rosty Geyyer](https://github.com/geyyer), [Astha Rai](https://github.com/arai713), [Shi YanXing](https://github.com/Yanxing-Shi), 2022-2023 + +[Hari Sadasivan](https://github.com/hsadasiv), [Bartlomiej Kocot](https://github.com/bartekxk), [Bartlomiej Wroblewski](https://github.com/bwroblew), 2023 Hanwen Chang, 2019-2021, diff --git a/Dockerfile b/Dockerfile index 8e6ddb1eba3dd934222fa4bd571729db6762dc69..7134e206c1738e502c493876f20aee10a7877d57 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ FROM ubuntu:20.04 - -ARG ROCMVERSION=5.6 +ARG DEBIAN_FRONTEND=noninteractive +ARG ROCMVERSION=5.7 ARG compiler_version="" ARG compiler_commit="" @@ -9,64 +9,72 @@ RUN set -xe ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ RUN useradd -rm -d /home/jenkins -s /bin/bash -u 1004 jenkins # Add rocm repository +RUN chmod 1777 /tmp RUN apt-get update -RUN apt-get install -y wget gnupg curl -RUN --mount=type=ssh if [ "$ROCMVERSION" != "5.6"]; then \ - wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list"; \ - else sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amd-nonfree-radeon_20.04-1_all.deb" && \ - apt update && apt-get install -y ./amd-nonfree-radeon_20.04-1_all.deb && \ - amdgpu-repo --amdgpu-build=1567752 --rocm-build=compute-rocm-dkms-no-npi-hipclang/11914 && \ - DEBIAN_FRONTEND=noninteractive amdgpu-install -y --usecase=rocm ; \ - fi -RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add - -RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" +RUN apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl + +ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg +RUN wget https://repo.radeon.com/amdgpu-install/5.7/ubuntu/focal/amdgpu-install_5.7.50700-1_all.deb --no-check-certificate +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_5.7.50700-1_all.deb + +RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ + sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ + sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list' + +RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" +RUN amdgpu-install -y --usecase=rocm --no-dkms + +## Sccache binary built from source for ROCm +ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache +ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin +RUN mkdir -p ${SCCACHE_INSTALL_LOCATION} && \ +curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \ +chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache +ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION} + # Install dependencies RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ - apt-utils \ build-essential \ - ccache \ cmake \ + ccache \ git \ hip-rocclr \ + iputils-ping \ jq \ libelf-dev \ libncurses5-dev \ libnuma-dev \ libpthread-stubs0-dev \ llvm-amdgpu \ + net-tools \ pkg-config \ python \ python3 \ - python-dev \ python3-dev \ python3-pip \ + redis \ sshpass \ + stunnel \ software-properties-common \ - rocm-dev \ - rocm-device-libs \ - rocm-cmake \ vim \ nano \ zlib1g-dev \ + zip \ openssh-server \ - clang-format-10 \ + clang-format-12 \ kmod && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* #Install latest version of cmake -RUN apt purge --auto-remove -y cmake -RUN apt update -RUN apt install -y software-properties-common lsb-release -RUN apt clean all -RUN wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | tee /etc/apt/trusted.gpg.d/kitware.gpg >/dev/null -RUN apt-add-repository "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" -RUN apt install -y kitware-archive-keyring -RUN rm /etc/apt/trusted.gpg.d/kitware.gpg -RUN apt install -y cmake +RUN wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip +RUN gunzip /usr/local/bin/ninja.gz +RUN chmod a+x /usr/local/bin/ninja +RUN git clone https://github.com/nico/ninjatracing.git +# Update the cmake to the latest version +RUN pip install --upgrade cmake==3.27.5 # Setup ubsan environment to printstacktrace RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer @@ -81,9 +89,9 @@ ARG PREFIX=/opt/rocm RUN pip3 install --upgrade pip RUN pip3 install sqlalchemy==1.4.46 RUN pip3 install pymysql -RUN pip3 install pandas +RUN pip3 install pandas==2.0.3 RUN pip3 install setuptools-rust -RUN pip3 install sshtunnel +RUN pip3 install sshtunnel==0.4.0 # Setup ubsan environment to printstacktrace ENV UBSAN_OPTIONS=print_stacktrace=1 @@ -103,22 +111,24 @@ ENV compiler_commit=$compiler_commit RUN sh -c "echo compiler version = '$compiler_version'" RUN sh -c "echo compiler commit = '$compiler_commit'" -RUN --mount=type=ssh if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; then \ +RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ - cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ make -j 8 ; \ else echo "using the release compiler"; \ fi -RUN --mount=type=ssh if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" != "" ]; then \ +RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" != "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ - cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ make -j 8 ; \ else echo "using the release compiler"; \ fi +#clean-up the deb package +RUN sh -c "rm -rf amdgpu-install*" #ENV HIP_CLANG_PATH='/llvm-project/build/bin' #RUN sh -c "echo HIP_CLANG_PATH = '$HIP_CLANG_PATH'" diff --git a/Jenkinsfile b/Jenkinsfile index e94fb1eef134eb5eefb41b5021973c4deea4c31b..6b6cf39b37af003197887b0573eddfe75d74151e 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -11,6 +11,20 @@ def show_node_info() { """ } +def nthreads() { + def nproc = sh(returnStdout: true, script: 'nproc') + echo "Number of cores: ${nproc}" + def n = nproc.toInteger() + if (n > 32){ + n /= 2 + } + if (n > 64){ + n = 64 + } + echo "Number of threads used for building: ${n}" + return n +} + def runShell(String command){ def responseCode = sh returnStatus: true, script: "${command} > tmp.txt" def output = readFile(file: "tmp.txt") @@ -19,7 +33,7 @@ def runShell(String command){ def getDockerImageName(){ def img - if (params.ROCMVERSION != "5.5" && params.ROCMVERSION != "5.6"){ + if (params.ROCMVERSION != "6.0"){ if (params.COMPILER_VERSION == "") { img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" } @@ -51,10 +65,10 @@ def getDockerImageName(){ } def check_host() { - if ("${env.CK_CCACHE}" != "null"){ - def CCACHE_SERVER="${env.CK_CCACHE.split(':')[0]}" - echo "ccache server: ${CCACHE_SERVER}" - sh '''ping -c 1 -p 6379 "${CCACHE_SERVER}" | echo $? > tmp.txt''' + if ("${env.CK_SCCACHE}" != "null"){ + def SCCACHE_SERVER="${env.CK_SCCACHE.split(':')[0]}" + echo "sccache server: ${SCCACHE_SERVER}" + sh '''ping -c 1 -p 6379 "${SCCACHE_SERVER}" | echo $? > tmp.txt''' def output = readFile(file: "tmp.txt") echo "tmp.txt contents: \$output" return (output != "0") @@ -82,24 +96,9 @@ def build_compiler(){ def getDockerImage(Map conf=[:]){ env.DOCKER_BUILDKIT=1 - def prefixpath = conf.get("prefixpath", "/opt/rocm") // prefix:/opt/rocm + def prefixpath = conf.get("prefixpath", "/opt/rocm") def no_cache = conf.get("no_cache", false) def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - echo "ccache server: ${env.CK_CCACHE}" - if(env.CK_CCACHE) - { - if(check_host()) - { - echo "FOUND CCACHE SERVER: ${env.CK_CCACHE}" - } - else - { - echo "CCACHE SERVER: ${env.CK_CCACHE} NOT FOUND, got ${check_host} response" - } - dockerArgs = dockerArgs + " --build-arg CCACHE_SECONDARY_STORAGE='redis://${env.CK_CCACHE}' --build-arg COMPILER_LAUNCHER='ccache' " - env.CCACHE_DIR = """/tmp/ccache_store""" - env.CCACHE_SECONDARY_STORAGE="""redis://${env.CK_CCACHE}""" - } if(no_cache) { dockerArgs = dockerArgs + " --no-cache " @@ -128,21 +127,6 @@ def buildDocker(install_prefix){ def image_name = getDockerImageName() echo "Building Docker for ${image_name}" def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - echo "ccache server: ${env.CK_CCACHE}" - if(env.CK_CCACHE) - { - if(check_host()) - { - echo "FOUND CCACHE SERVER: ${env.CK_CCACHE}" - } - else - { - echo "CCACHE SERVER: ${env.CK_CCACHE} NOT FOUND, got ${check_host} response" - } - dockerArgs = dockerArgs + " --build-arg CCACHE_SECONDARY_STORAGE='redis://${env.CK_CCACHE}' --build-arg COMPILER_LAUNCHER='ccache' " - env.CCACHE_DIR = """/tmp/ccache_store""" - env.CCACHE_SECONDARY_STORAGE="""redis://${env.CK_CCACHE}""" - } echo "Build Args: ${dockerArgs}" try{ @@ -155,7 +139,7 @@ def buildDocker(install_prefix){ else{ echo "Checking for image: ${image_name}" sh "docker manifest inspect --insecure ${image_name}" - echo "Image: ${image_name} found!! Skipping building image" + echo "Image: ${image_name} found! Skipping building image" } } catch(Exception ex){ @@ -196,19 +180,18 @@ def cmake_build(Map conf=[:]){ } else{ setup_args = ' -DBUILD_DEV=On' + setup_args } + if (params.DL_KERNELS){ + setup_args = setup_args + " -DDL_KERNELS=ON " + } if(build_type_debug){ setup_args = " -DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'" + setup_args }else{ setup_args = " -DCMAKE_BUILD_TYPE=release" + setup_args } - if(env.CK_CCACHE) - { - setup_args = " -DCMAKE_CXX_COMPILER_LAUNCHER='ccache' -DCMAKE_C_COMPILER_LAUNCHER='ccache' " + setup_args - } - echo "ccache server: ${env.CK_CCACHE}" def pre_setup_cmd = """ + #!/bin/bash echo \$HSA_ENABLE_SDMA ulimit -c unlimited rm -rf build @@ -217,23 +200,80 @@ def cmake_build(Map conf=[:]){ mkdir install cd build """ + def invocation_tag="" + if (setup_args.contains("gfx11")){ + invocation_tag="gfx11" + } + if (setup_args.contains("gfx10")){ + invocation_tag="gfx10" + } + if (setup_args.contains("gfx90")){ + invocation_tag="gfx90" + } + if (setup_args.contains("gfx94")){ + invocation_tag="gfx94" + } + echo "invocation tag: ${invocation_tag}" + def redis_pre_setup_cmd = pre_setup_cmd + if(check_host() && params.USE_SCCACHE && "${env.CK_SCCACHE}" != "null" && "${invocation_tag}" != "") { + redis_pre_setup_cmd = pre_setup_cmd + """ + #!/bin/bash + export ROCM_PATH=/opt/rocm + export SCCACHE_ENABLED=true + export SCCACHE_LOG_LEVEL=debug + export SCCACHE_IDLE_TIMEOUT=14400 + export COMPILERS_HASH_DIR=/tmp/.sccache + export SCCACHE_BIN=/usr/local/.cargo/bin/sccache + export SCCACHE_EXTRAFILES=/tmp/.sccache/rocm_compilers_hash_file + export SCCACHE_REDIS="redis://${env.CK_SCCACHE}" + echo "connect = ${env.CK_SCCACHE}" >> ../script/redis-cli.conf + export SCCACHE_C_CUSTOM_CACHE_BUSTER="${invocation_tag}" + echo \$SCCACHE_C_CUSTOM_CACHE_BUSTER + stunnel ../script/redis-cli.conf + ../script/sccache_wrapper.sh --enforce_redis + """ + try { + def cmd1 = conf.get("cmd1", """ + ${redis_pre_setup_cmd} + """) + sh cmd1 + setup_args = " -DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache " + setup_args + } + catch(Exception err){ + echo "could not connect to redis server: ${err.getMessage()}. will not use sccache." + def cmd2 = conf.get("cmd2", """ + ${pre_setup_cmd} + """) + sh cmd2 + } + } + else{ + def cmd3 = conf.get("cmd3", """ + ${pre_setup_cmd} + """) + sh cmd3 + } + def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") // reduce parallelism when compiling, clang uses too much memory - def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j\$(( \$(nproc) / 2 )) ${config_targets}") + def nt = nthreads() + def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}") def execute_cmd = conf.get("execute_cmd", "") def cmd = conf.get("cmd", """ - ${pre_setup_cmd} ${setup_cmd} ${build_cmd} ${execute_cmd} """) echo cmd - sh cmd + + dir("build"){ + sh cmd + } // Only archive from master or develop - if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "master")) { + if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "amd-master")) { archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true } } @@ -352,8 +392,6 @@ def runCKProfiler(Map conf=[:]){ withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 24, unit: 'HOURS') { - //cmake_build(conf) - //instead of building, just unstash the ckProfiler and install it sh """ rm -rf build mkdir build @@ -461,7 +499,7 @@ def Build_CK(Map conf=[:]){ else{ echo "GPU is OK" } - if ( runShell('grep -n "gfx1030" clinfo.log') ){ + if ( runShell('grep -n "gfx1030" clinfo.log') || runShell('grep -n "gfx1101" clinfo.log') ){ navi_node = 1 } } @@ -482,7 +520,7 @@ def Build_CK(Map conf=[:]){ else{ echo "GPU is OK" } - if ( runShell('grep -n "gfx1030" clinfo.log') ){ + if ( runShell('grep -n "gfx1030" clinfo.log') || runShell('grep -n "gfx1101" clinfo.log') ){ navi_node = 1 } } @@ -493,10 +531,11 @@ def Build_CK(Map conf=[:]){ { cmake_build(conf) dir("build"){ + //run tests and examples + sh 'make -j check' if (navi_node == 0 ){ - //run tests and examples on all nodes except Navi - sh 'make -j check' - //we only need the ckProfiler to run the performance tests, so we pack and stash it + //we only need the ckProfiler to run the performance tests, so we pack and stash it + //do not stash profiler on Navi nodes sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' stash "ckProfiler.tar.gz" } @@ -509,6 +548,26 @@ def Build_CK(Map conf=[:]){ stash "ckprofiler_0.2.0_amd64.deb" } } + if (params.hipTensor_test && navi_node == 0 ){ + //build and test hipTensor + sh """#!/bin/bash + rm -rf "${params.hipTensor_branch}".zip + rm -rf hipTensor-"${params.hipTensor_branch}" + wget https://github.com/ROCmSoftwarePlatform/hipTensor/archive/refs/heads/"${params.hipTensor_branch}".zip + unzip -o "${params.hipTensor_branch}".zip + """ + dir("hipTensor-${params.hipTensor_branch}"){ + sh """#!/bin/bash + mkdir -p build + ls -ltr + CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="/opt/rocm;${env.WORKSPACE}/install" + cmake --build build -- -j + """ + } + dir("hipTensor-${params.hipTensor_branch}/build"){ + sh 'ctest' + } + } } } } @@ -607,8 +666,8 @@ pipeline { description: "Force building docker image (default: false), set to true if docker image needs to be updated.") string( name: 'ROCMVERSION', - defaultValue: '5.6', - description: 'Specify which ROCM version to use: 5.6 (default).') + defaultValue: '5.7', + description: 'Specify which ROCM version to use: 5.7 (default).') string( name: 'COMPILER_VERSION', defaultValue: '', @@ -625,6 +684,22 @@ pipeline { name: "RUN_FULL_QA", defaultValue: false, description: "Select whether to run small set of performance tests (default) or full QA") + booleanParam( + name: "DL_KERNELS", + defaultValue: false, + description: "Select whether to build DL kernels (default: OFF)") + booleanParam( + name: "hipTensor_test", + defaultValue: true, + description: "Use the CK build to verify hipTensor build and tests (default: ON)") + string( + name: 'hipTensor_branch', + defaultValue: 'mainline', + description: 'Specify which branch of hipTensor to use (default: mainline)') + booleanParam( + name: "USE_SCCACHE", + defaultValue: true, + description: "Use the sccache for building CK (default: ON)") } environment{ dbuser = "${dbuser}" @@ -639,15 +714,12 @@ pipeline { } stages{ stage("Build Docker"){ - //when { - // beforeAgent true - // expression { params.BUILD_DOCKER.toBoolean() } - //} parallel{ stage('Docker /opt/rocm'){ agent{ label rocmnode("nogpu") } steps{ buildDocker('/opt/rocm') + cleanWs() } } } @@ -665,10 +737,11 @@ pipeline { -o -iname \'*.cpp.in\' \ -o -iname \'*.cl\' \ | grep -v 'build/' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-10 -style=file {} | diff - {}\'" + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'" } steps{ buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) + cleanWs() } } } @@ -678,18 +751,39 @@ pipeline { { parallel { + stage("Build CK and run Tests on MI100/MI200/MI300") + { + when { + beforeAgent true + expression { params.RUN_FULL_QA.toBoolean() } + } + agent{ label rocmnode("gfx908 || gfx90a") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() + } + } stage("Build CK and run Tests on MI100/MI200") { + when { + beforeAgent true + expression { !params.RUN_FULL_QA.toBoolean() } + } agent{ label rocmnode("gfx908 || gfx90a") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908,gfx90a" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() } } - stage("Build CK and run Tests on Navi") + stage("Build CK and run Tests on Navi21") { when { beforeAgent true @@ -697,12 +791,28 @@ pipeline { } agent{ label rocmnode("navi21") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030;gfx1100;gfx1101;gfx1102" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ - + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() + } + } + stage("Build CK and run Tests on Navi32") + { + when { + beforeAgent true + expression { !params.RUN_FULL_QA.toBoolean() } + } + agent{ label rocmnode("navi32") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDL_KERNELS=ON -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() } } } @@ -725,6 +835,7 @@ pipeline { } steps{ runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') + cleanWs() } } stage("Run ckProfiler: gfx90a") @@ -740,6 +851,7 @@ pipeline { } steps{ runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') + cleanWs() } } } @@ -752,6 +864,7 @@ pipeline { agent { label 'mici' } steps{ process_results() + cleanWs() } } } diff --git a/LICENSE b/LICENSE index 2fe9a8455efaeda2eab474b2aa038ec2d9e76841..e03fddaf78080705d26ec277629cfb8010c077bf 100644 --- a/LICENSE +++ b/LICENSE @@ -7,7 +7,7 @@ Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) SPDX-License-Identifier: MIT -Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 04199f11bf27db5fa93ea174b94bbe987cd1a15c..e5a20f143fbcd0d326bf65e95f9daea3dcd52975 100644 --- a/README.md +++ b/README.md @@ -1,106 +1,195 @@ # Composable Kernel -## Methodology -Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. +The Composable Kernel (CK) library provides a programming model for writing performance-critical +kernels for machine learning workloads across multiple architectures (GPUs, CPUs, etc.). The CK library +uses general purpose kernel languages, such as HIP C++. + +CK uses two concepts to achieve performance portability and code maintainability: -CK utilizes two concepts to achieve performance portability and code maintainability: * A tile-based programming model -* Algorithm complexity reduction for complex ML operators, using innovative technique we call "Tensor Coordinate Transformation". +* Algorithm complexity reduction for complex machine learning (ML) operators. This uses an innovative + technique called *Tensor Coordinate Transformation*. ![ALT](/docs/data/ck_component.png "CK Components") -## Code Structure -Current CK library are structured into 4 layers: -* "Templated Tile Operators" layer -* "Templated Kernel and Invoker" layer -* "Instantiated Kernel and Invoker" layer -* "Client API" layer +The current CK library is structured into four layers: + +* Templated Tile Operators +* Templated Kernel and Invoker +* Instantiated Kernel and Invoker +* Client API ![ALT](/docs/data/ck_layer.png "CK Layers") -## Documentation +## General information -Run the steps below to build documentation locally. +To build our documentation locally, use the following code: -``` +``` bash cd docs -pip3 install -r .sphinx/requirements.txt +pip3 install -r sphinx/requirements.txt python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html ``` -## Contributors -The list of developers and contributors is here: [Contributors](/CONTRIBUTORS.md) +You can find a list of our developers and contributors on our [Contributors](/CONTRIBUTORS.md) page. +page. -## Citation -If you use CK, please use following citations: -* CK paper will be freely available on arXiv soon: [Realizing Tensor Operators Using Coordinate Transformations and Tile Based Programming](???) +```note +If you use CK, cite us as follows: + +* [Realizing Tensor Operators Using Coordinate Transformations and Tile Based Programming](???): + This paper will be available on arXiv soon. * [CITATION.cff](/CITATION.cff) +``` -## License -CK is released under the MIT license. [License File](/LICENSE) +CK is released under the **[MIT license](/LICENSE)**. +## Building CK -# Build CK +We recommend building CK inside Docker containers, which include all necessary packages. Pre-built +Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composable_kernel/tags). -## Build docker image -```bash -DOCKER_BUILDKIT=1 docker build -t ck:latest -f Dockerfile . -``` +1. To build a new Docker image, use the Dockerfile provided with the source code: -## Launch docker -```bash -docker run \ --it \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ -ck:latest \ -/bin/bash -``` + ```bash + DOCKER_BUILDKIT=1 docker build -t ck:latest -f Dockerfile . + ``` + +2. Launch the Docker container: + + ```bash + docker run \ + -it \ + --privileged \ + --group-add sudo \ + -w /root/workspace \ + -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ + ck:latest \ + /bin/bash + ``` + +3. Clone CK source code from the GitHub repository and start the build: + + ```bash + git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git && \ + cd composable_kernel && \ + mkdir build && \ + cd build + ``` + + You must set the `GPU_TARGETS` macro to specify the GPU target architecture(s) you want + to run CK on. You can specify single or multiple architectures. If you specify multiple architectures, + use a semicolon between each; for example, `gfx908;gfx90a;gfx940`. + + ```bash + cmake \ + -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx908;gfx90a" \ + .. + ``` + + If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets + supported by the current compiler (this may take a long time). + +4. Build the entire CK library: + + ```bash + make -j + ``` + +5. Install CK: + + ```bash + make -j install + ``` + +## Optional post-install steps + +* Build examples and tests: + + ```bash + make -j examples tests + ``` + +* Build and run all examples and tests: + + ```bash + make -j check + ``` + + You can find instructions for running each individual example in [example](/example). + +* Build ckProfiler: + + ```bash + make -j ckProfiler + ``` + + You can find instructions for running ckProfiler in [profiler](/profiler). + +Note the `-j` option for building with multiple threads in parallel. This speeds up the build significantly. +Depending on the number of CPU cores and the amount of RAM on your system, you may want to +limit the number of threads. For example, if you have a 128-core CPU and 64 Gb of RAM. + +By default, `-j` launches one thread per CPU core, which can cause the build to run out of memory and +crash. In such cases, you can reduce the number of threads to 32 by using `-j32`. + +Additional cmake flags can be used to significantly speed-up the build: + +* `INSTANCES_ONLY` (default is OFF) must be set to ON in order to build only the instances and library + while skipping all tests, examples, and profiler. This is useful in cases when you plan to use CK as a + dependency and don't plan to run any examples or tests. + +* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build + instances of select data types only. The main default data types are fp32 and fp16; you can safely skip + other data types. + +* `DL_KERNELS` (default is OFF) must be set to ON in order to build instances, such as `gemm_dl` or + `batched_gemm_multi_d_dl`. These instances are useful on architectures like the NAVI2x, as most + other platforms have faster instances, such as `xdl` or `wmma`, available. + +## Using sccache for building + +The default CK Docker images come with a pre-installed version of sccache, which supports clang +being used as hip-compiler (" -x hip"). Using sccache can help reduce the time to re-build code from +hours to 1-2 minutes. In order to invoke sccache, you need to run: -## Build CK ```bash -mkdir build && cd build - -# Need to specify target ID, example below is for gfx908 and gfx90a -cmake \ --D CMAKE_PREFIX_PATH=/opt/rocm \ --D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-O3" \ --D CMAKE_BUILD_TYPE=Release \ --D GPU_TARGETS="gfx908;gfx90a" \ -.. + sccache --start-server ``` -### Build examples and tests +then add the following flags to the cmake command line: + ```bash - make -j examples tests - make test + -DCMAKE_CXX_COMPILER_LAUNCHER=sccache -DCMAKE_C_COMPILER_LAUNCHER=sccache ``` -Instructions for running each individual examples are under [example](/example) +You may need to clean up the build folder and repeat the cmake and make steps in order to take +advantage of the sccache during subsequent builds. + +## Using CK as pre-built kernel library +You can find instructions for using CK as a pre-built kernel library in [client_example](/client_example). + +## Contributing to CK + +When you contribute to CK, make sure you run `clang-format` on all changed files. We highly +recommend using git hooks that are managed by the `pre-commit` framework. To install hooks, run: -## Build ckProfiler ```bash - make -j ckProfiler +sudo script/install_precommit.sh ``` -Instructions for running ckProfiler are under [profiler](/profiler) -## Install CK +With this approach, `pre-commit` adds the appropriate hooks to your local repository and +automatically runs `clang-format` (and possibly additional checks) before any commit is created. + +If you need to uninstall hooks from the repository, you can do so by running the following command: + ```bash -make install +script/uninstall_precommit.sh ``` -## Using CK as pre-built kernel library -Instructions for using CK as a pre-built kernel library are under [client_example](/client_example) - -## Caveat -### Kernel Timing and Verification -CK's own kernel timer will warn up kernel once, and then run it multiple times -to get average kernel time. For some kernels that use atomic add, this will cause -output buffer to be accumulated multiple times, causing verification failure. -To work around it, do not use CK's own timer and do verification at the same time. -CK's own timer and verification in each example and ckProfiler can be enabled or -disabled from command line. +If you need to temporarily disable pre-commit hooks, you can add the `--no-verify` option to the +`git commit` command. diff --git a/client_example/01_gemm/gemm.cpp b/client_example/01_gemm/gemm.cpp index ba7118ba3929e3e3bcdf02e40044748860bfeebe..c37f208db1cee9bfff1fff469fd79d059fb179f0 100644 --- a/client_example/01_gemm/gemm.cpp +++ b/client_example/01_gemm/gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt index b7b724ccc484d6c9ea83e791aa5b8fcde420eef7..ba2952022233b3ffd21fc245b8fd199a5ec5b096 100644 --- a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt +++ b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt @@ -11,3 +11,17 @@ target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_ope add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu client_gemm_fastgelu) + +add_custom_target(client_gemm_fastgelu_generic_examples) + +add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp) +target_link_libraries(client_gemm_add_add_fastgelu_generic PRIVATE composable_kernel::device_operations) + +add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) +target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_operations) + +add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) +target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_operations) + +add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic + client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic) diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp index 08f297f58a8d522aec8c991a8d38484ccf8e420c..756889562e84c66efb5f972621bcb61edda3af82 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2ed942f0adf64df90ea1046d1de191ce2247d7c4 --- /dev/null +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +using ADataType = F16; +using BDataType = F16; +using D0DataType = F16; +using D1DataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using ELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD0 = 0; + ck::index_t StrideD1 = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 9) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideD0 = std::stoi(argv[6]); + StrideD1 = std::stoi(argv[7]); + StrideE = std::stoi(argv[8]); + } + else + { + printf("arg1 to 8: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem d0_m_n_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD0, D0Layout{})); + SimpleDeviceMem d1_m_n_device_buf(sizeof(D1DataType) * + f_matrix_space_size(M, N, StrideD1, D1Layout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddAddFastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + // get generic instance + auto& op_ptr = op_ptrs[0]; + + std::cout << "Run the generic instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + // run the generic instance + auto argument_ptr = + op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + else + { + throw std::runtime_error( + "Generic instance should be suitable for various input lengths/strides"); + } + + std::cout << "Done" << std::endl; + + return 0; +} diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp index 658c1e9e8fcbeab1ec7be4115d5a7d62c5e77283..8d2a8c234aae63a3566478b5aa9588389a247d4e 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -76,7 +76,7 @@ int main(int argc, char* argv[]) StrideA = std::stoi(argv[4]); StrideB = std::stoi(argv[5]); StrideD0 = std::stoi(argv[6]); - StrideE = std::stoi(argv[8]); + StrideE = std::stoi(argv[7]); } else { diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp new file mode 100644 index 0000000000000000000000000000000000000000..644b428fc9f51b28a0f81ed7d79ac741ca6fbcbe --- /dev/null +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddFastGelu; + +using ADataType = F16; +using BDataType = F16; +using D0DataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using ELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD0 = 0; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 8) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideD0 = std::stoi(argv[6]); + StrideE = std::stoi(argv[7]); + } + else + { + printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideD0, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem d0_m_n_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD0, D0Layout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddFastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + // get generic instance + auto& op_ptr = op_ptrs[0]; + + std::cout << "Run the generic instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + // run the generic instance + auto argument_ptr = + op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + else + { + throw std::runtime_error( + "Generic instance should be suitable for various input lengths/strides"); + } + + std::cout << "Done" << std::endl; + + return 0; +} diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp index ea269545a5cc576f7cf98ea8496acf713b6aaea6..c02df018fd35c6d37a2c6f9b9fded6390c9afb19 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -72,7 +72,7 @@ int main(int argc, char* argv[]) StrideA = std::stoi(argv[4]); StrideB = std::stoi(argv[5]); - StrideE = std::stoi(argv[8]); + StrideE = std::stoi(argv[6]); } else { diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp new file mode 100644 index 0000000000000000000000000000000000000000..482e93b421f7700cdf37ee014cf407ca3c63555b --- /dev/null +++ b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = FastGelu; + +using ADataType = F16; +using BDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 7) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideE = std::stoi(argv[6]); + } + else + { + printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideE\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple<>, + ELayout, + ADataType, + BDataType, + ck::Tuple<>, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::FastGelu>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + // get generic instance + auto& op_ptr = op_ptrs[0]; + + std::cout << "Run the generic instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + // run the generic instance + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + else + { + throw std::runtime_error( + "Generic instance should be suitable for various input lengths/strides"); + } + + std::cout << "Done" << std::endl; + + return 0; +} diff --git a/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp b/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp index caa6573788d201b9fda605e5388423d5966d41ec..58c91f903bc7e2f3f296b06c746b1629513bc7f2 100644 --- a/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp +++ b/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -172,18 +172,19 @@ int main() BLayout, CLayout>(); - const auto normalize_ptrs = - ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances< - CDataType, - ReduceDataType, - ReduceDataType, - GammaDataType, - BetaDataType, - LayerNormOutDataType>(); - std::cout << "found " << gemm_reduce_ptrs.size() << " gemm_reduceMean_reduceSquareMean instances" << std::endl; + using NormalizeDeviceOp = ck::tensor_operation::device::DeviceElementwise< + ck::Tuple, + ck::Tuple, + ck::tensor_operation::element_wise::Normalize, + 2>; + + const auto normalize_ptrs = + ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + NormalizeDeviceOp>::GetInstances(); + std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl; auto f_matrix_space_size = diff --git a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp index d4f0c2048ba81c3a1f98a22c5eb57654b0498806..3d5fb6004844af269f7786ae05de8c20cc720633 100644 --- a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp +++ b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_bilinear_fp32.cpp b/client_example/04_contraction/contraction_bilinear_fp32.cpp index 91dead41a4cac19db857b99a233839e9e6647c57..89f834b9824e134f8f0aeed8aa54f78a5c8824a3 100644 --- a/client_example/04_contraction/contraction_bilinear_fp32.cpp +++ b/client_example/04_contraction/contraction_bilinear_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_bilinear_fp64.cpp b/client_example/04_contraction/contraction_bilinear_fp64.cpp index 9238e4cd80075a2caf47d8757a24f7cf82c4b8bd..1aa3ba7de597a0b97a295b0f7ee7ad21b1e9cd80 100644 --- a/client_example/04_contraction/contraction_bilinear_fp64.cpp +++ b/client_example/04_contraction/contraction_bilinear_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp b/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp index 62be3377a2fb18bf388b857ecb758c0b7987871c..f8ea2258c2ba262e9db42f1b4dc92ff16cdc6286 100644 --- a/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp +++ b/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_scale_fp32.cpp b/client_example/04_contraction/contraction_scale_fp32.cpp index 4e08ee19cdb098b2dfb70a662d59c87008400123..ba7b0633c33aabeeb06547edfb08f506e637e599 100644 --- a/client_example/04_contraction/contraction_scale_fp32.cpp +++ b/client_example/04_contraction/contraction_scale_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_scale_fp64.cpp b/client_example/04_contraction/contraction_scale_fp64.cpp index 3c36aa21eb6c34df75f765cce894d4f137d4f080..24e52eb5aa423339ff96ad0914dc479d715fe7b7 100644 --- a/client_example/04_contraction/contraction_scale_fp64.cpp +++ b/client_example/04_contraction/contraction_scale_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/05_layernorm/layernorm2d.cpp b/client_example/05_layernorm/layernorm2d.cpp index 856a4cc21935f094bfd7040ea095fb04d78b7eb4..7a8e5fec9982552e4f2e4f130e5ca9f1809cf6ec 100644 --- a/client_example/05_layernorm/layernorm2d.cpp +++ b/client_example/05_layernorm/layernorm2d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -12,12 +12,14 @@ #include "ck/library/tensor_operation_instance/gpu/normalization.hpp" -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD constexpr int Rank = 2; constexpr int NumReduceDim = 1; @@ -50,12 +52,16 @@ int main(int argc, char* argv[]) SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N); SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N); SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size); +#ifdef SAVE_MEAN_INV_STD + SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * M); + SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * M); +#endif using DeviceOp = ck::tensor_operation::device::DeviceNormalization; @@ -84,14 +90,21 @@ int main(int argc, char* argv[]) {0, 1}, // gammaStrides {0, 1}, // betaStrides {Stride, 1}, // yStrides + {1}, // save_mean Strides + {1}, // save_inv_std Strides {1}, // reduceDims 1e-4, x_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(), beta_device_buf.GetDeviceBuffer(), y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else nullptr, nullptr, +#endif PassThrough{}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); @@ -100,11 +113,19 @@ int main(int argc, char* argv[]) if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + sizeof(BetaDataType) * N + sizeof(YDataType) * M * N; +#ifdef SAVE_MEAN_INV_STD + num_byte += sizeof(SaveMeanInvStdDataType) * M * 2; +#endif + float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " @@ -136,23 +157,34 @@ int main(int argc, char* argv[]) auto argument_ptr = op_ptr->MakeArgumentPointer({M, N}, // lengths {Stride, 1}, // xStrides - {1}, // gammaStrides - {1}, // betaStrides + {0, 1}, // gammaStrides + {0, 1}, // betaStrides {Stride, 1}, // yStrides + {1}, // save_mean Strides + {1}, // save_inv_std Strides {1}, // reduceDims 1e-4, x_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(), beta_device_buf.GetDeviceBuffer(), y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else nullptr, nullptr, +#endif PassThrough{}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); } diff --git a/client_example/06_softmax/softmax4d.cpp b/client_example/06_softmax/softmax4d.cpp index e939ce8dfedb10166b15388cae1d659d33e2154a..2ccad27a88757b576a050d0558acc4108857cbd1 100644 --- a/client_example/06_softmax/softmax4d.cpp +++ b/client_example/06_softmax/softmax4d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -53,12 +53,35 @@ int main(int argc, char* argv[]) SimpleDeviceMem in(sizeof(InDataType) * num_elements); SimpleDeviceMem out(sizeof(OutDataType) * num_elements); - using DeviceOp = ck::tensor_operation::device:: - DeviceSoftmax; + using DeviceOp = ck::tensor_operation::device::DeviceSoftmax; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); + auto& generic_op_ptr = op_ptrs[0]; + + auto generic_argument_ptr = generic_op_ptr->MakeArgumentPointer(in_lengths, + in_strides, + reduce_dims, + alpha, + beta, + in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + PassThrough{}, + PassThrough{}); + + if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get())) + { + throw std::runtime_error( + "The generic kernel instance should be able to support any input shapes"); + }; + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::string best_op_name; @@ -74,11 +97,6 @@ int main(int argc, char* argv[]) { auto& op_ptr = op_ptrs[i]; - if(op_ptr->GetRank() != Rank || op_ptr->GetNumReduceDim() != NumReduceDim) - { - continue; - } - auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths, in_strides, reduce_dims, diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp index 9fbdb83b1cf10082958730e5115f3bf54e12d415..70be0101c6d92512ab28a72797d2b8c46fb55281 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp index ece6e30c56004f58ff136502b4d8874e8269cd3a..57a210fa1f5f0d3f0eca014e08ccd119d29e4fd6 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -17,22 +17,22 @@ using InDataType = ck::half_t; using WeiDataType = ck::half_t; using OutDataType = ck::half_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr ck::index_t NumDimSpatial = 2; static constexpr ck::index_t G = 32; -static constexpr ck::index_t N = 256; -static constexpr ck::index_t K = 192; -static constexpr ck::index_t C = 192; -static constexpr ck::index_t Y = 3; -static constexpr ck::index_t X = 3; -static constexpr ck::index_t Hi = 28; -static constexpr ck::index_t Wi = 28; -static constexpr ck::index_t Ho = 28; -static constexpr ck::index_t Wo = 28; +static constexpr ck::index_t N = 256; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 28; // input H +static constexpr ck::index_t Wi = 28; // input W +static constexpr ck::index_t Ho = 28; // output H +static constexpr ck::index_t Wo = 28; // output W struct SimpleDeviceMem { @@ -52,50 +52,24 @@ struct SimpleDeviceMem int main() { - std::array in_lengths{G, N, Hi, Wi, C}; - std::array in_strides{0, 0, 0, 0, 1}; - - std::array wei_lengths{G, K, Y, X, C}; - std::array wei_strides{0, 0, 0, 0, 1}; - - std::array out_lengths{G, N, Ho, Wo, K}; - std::array out_strides{0, 0, 0, 0, 1}; - - std::partial_sum(rbegin(in_lengths), - std::prev(rend(in_lengths)), - std::next(rbegin(in_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(wei_lengths), - std::prev(rend(wei_lengths)), - std::next(rbegin(wei_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(out_lengths), - std::prev(rend(out_lengths)), - std::next(rbegin(out_strides)), - std::multiplies<>{}); - - // transpose GNHWC/GKYXC/GNHWK to GNCHW/GKCYX/GNCHW - std::rotate( - rbegin(in_lengths), std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 3)); - std::rotate( - rbegin(in_strides), std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 3)); - std::rotate( - rbegin(wei_lengths), std::next(rbegin(wei_lengths)), std::next(rbegin(wei_lengths), 3)); - std::rotate( - rbegin(wei_strides), std::next(rbegin(wei_strides)), std::next(rbegin(wei_strides), 3)); - std::rotate( - rbegin(out_lengths), std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 3)); - std::rotate( - rbegin(out_strides), std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 3)); + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; + std::array wei_lengths{G, K, C, Y, X}; + std::array wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array out_lengths{G, N, K, Ho, Wo}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; std::array filter_strides{1, 1}; std::array filter_dilations{1, 1}; std::array input_left_pads{1, 1}; std::array input_right_pads{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * G * N * Hi * Wi * C); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleDRun(argument_ptr.get(), StreamConfig{nullptr, true}); std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; - std::size_t num_bytes = sizeof(InDataType) * G * N * Hi * Wi * C + + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + sizeof(WeiDataType) * G * K * Y * X * C + - sizeof(OutDataType) * G * N * Ho * Wo * K; + sizeof(OutDataType) * N * Ho * Wo * G * K; float tflops = static_cast(flop) / 1.E9 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time; diff --git a/client_example/08_fused_attention/fused_attention.cpp b/client_example/08_fused_attention/fused_attention.cpp index fe927da1248786a4b943f610ce38b75f0d88defd..df6bc11a70d32df221612466a8af0fbcd9cafb1c 100644 --- a/client_example/08_fused_attention/fused_attention.cpp +++ b/client_example/08_fused_attention/fused_attention.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/08_fused_attention/fused_attention_bias.cpp b/client_example/08_fused_attention/fused_attention_bias.cpp index 3113b7856025af74c610981db130a8d965c36a24..6c9f3bc8f6f5a3c06f339f1246b5b3985e11d2d8 100644 --- a/client_example/08_fused_attention/fused_attention_bias.cpp +++ b/client_example/08_fused_attention/fused_attention_bias.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/CMakeLists.txt b/client_example/09_quantization/CMakeLists.txt index 2b7d6fc806ad48f650c3e1f72ef3190cb9f342f4..ac11aad45de84b8a2e042d8a1a4bb059b889f3f2 100644 --- a/client_example/09_quantization/CMakeLists.txt +++ b/client_example/09_quantization/CMakeLists.txt @@ -1,3 +1,4 @@ +if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations) @@ -18,3 +19,4 @@ target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable add_executable(client_gemm_quantization gemm_quantization.cpp) target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_operations) +endif() diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp index a10dd3e006a48773cec9d06ce188e16df44764d5..cd504e942e943b8f174442554eca379cd908ba6e 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -17,26 +17,26 @@ using BiasDataType = int32_t; using RequantScaleDataType = float; using OutDataType = int8_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; using BiasLayout = ck::tensor_layout::convolution::G_K; using RequantScaleLayout = ck::tensor_layout::convolution::G_K; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = ck::tensor_operation::element_wise::Relu; using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; -static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel -static constexpr ck::index_t Y = 3; // filter H -static constexpr ck::index_t X = 3; // filter W -static constexpr ck::index_t Hi = 71; // input H -static constexpr ck::index_t Wi = 71; // input W -static constexpr ck::index_t Ho = 36; // output H -static constexpr ck::index_t Wo = 36; // output W +static constexpr ck::index_t G = 4; +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W struct SimpleDeviceMem { SimpleDeviceMem() = delete; @@ -55,8 +55,11 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array bias_lengths{G, N, K, Ho, Wo}; @@ -64,17 +67,18 @@ int main(int argc, char* argv[]) std::array requant_scale_lengths{G, N, K, Ho, Wo}; std::array requant_scale_strides{K, 0, 1, 0, 0}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); - SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * G * K); + SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< NumDimSpatial, diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp index b8e6a493efc9b10c2bd6705ee53ce9f87093f024..f4aa3666b1c15eccd37bdef549f8402bcd1b252b 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -16,19 +16,19 @@ using WeiDataType = int8_t; using BiasDataType = int32_t; using OutDataType = int8_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; using BiasLayout = ck::tensor_layout::convolution::G_K; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = ck::tensor_operation::element_wise::Relu; using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; +static constexpr ck::index_t G = 4; static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) static constexpr ck::index_t Y = 3; // filter H static constexpr ck::index_t X = 3; // filter W static constexpr ck::index_t Hi = 71; // input H @@ -55,23 +55,27 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array bias_lengths{G, N, K, Ho, Wo}; std::array bias_strides{K, 0, 1, 0, 0}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD #include @@ -17,21 +17,21 @@ using BiasDataType = int32_t; using RequantScaleDataType = float; using OutDataType = int8_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; using BiasLayout = ck::tensor_layout::convolution::G_K; using RequantScaleLayout = ck::tensor_layout::convolution::G_K; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = ck::tensor_operation::element_wise::TanH; using OutElementOp = ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; +static constexpr ck::index_t G = 4; static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) static constexpr ck::index_t Y = 3; // filter H static constexpr ck::index_t X = 3; // filter W static constexpr ck::index_t Hi = 71; // input H @@ -58,8 +58,11 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array bias_lengths{G, N, K, Ho, Wo}; @@ -67,17 +70,18 @@ int main(int argc, char* argv[]) std::array requant_scale_lengths{G, N, K, Ho, Wo}; std::array requant_scale_strides{K, 0, 1, 0, 0}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); - SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * G * K); + SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< NumDimSpatial, diff --git a/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp index 7637f5c7853e9fcaa226d4ffedbcd20b414cf9a3..9d60baee06fddb27451f1f51642bad738739c4bb 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -16,19 +16,19 @@ using WeiDataType = int8_t; using BiasDataType = int32_t; using OutDataType = int8_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; using BiasLayout = ck::tensor_layout::convolution::G_K; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = ck::tensor_operation::element_wise::TanH; using OutElementOp = ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; +static constexpr ck::index_t G = 4; static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) static constexpr ck::index_t Y = 3; // filter H static constexpr ck::index_t X = 3; // filter W static constexpr ck::index_t Hi = 71; // input H @@ -56,23 +56,27 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array bias_lengths{G, N, K, Ho, Wo}; std::array bias_strides{K, 0, 1, 0, 0}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD #include @@ -16,25 +16,25 @@ using WeiDataType = int8_t; using RequantScaleDataType = float; using OutDataType = int8_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; using RequantScaleLayout = ck::tensor_layout::convolution::G_K; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = PassThrough; using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; -static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel -static constexpr ck::index_t Y = 3; // filter H -static constexpr ck::index_t X = 3; // filter W -static constexpr ck::index_t Hi = 71; // input H -static constexpr ck::index_t Wi = 71; // input W -static constexpr ck::index_t Ho = 36; // output H -static constexpr ck::index_t Wo = 36; // output W +static constexpr ck::index_t G = 4; +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W struct SimpleDeviceMem { @@ -54,23 +54,27 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array requant_scale_lengths{G, N, K, Ho, Wo}; std::array requant_scale_strides{K, 0, 1, 0, 0}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD #include @@ -15,18 +15,18 @@ using InDataType = int8_t; using WeiDataType = int8_t; using OutDataType = int8_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = PassThrough; using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; +static constexpr ck::index_t G = 4; static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) static constexpr ck::index_t Y = 3; // filter H static constexpr ck::index_t X = 3; // filter W static constexpr ck::index_t Hi = 71; // input H @@ -53,20 +53,24 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD #include diff --git a/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt b/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt deleted file mode 100644 index e564f3180d8c9295356d90266f2159de47aac060..0000000000000000000000000000000000000000 --- a/client_example/10_grouped_conv2d_bwd_data/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_executable(client_grouped_conv2d_bwd_data grouped_conv2d_bwd_data.cpp) -target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::device_operations) diff --git a/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt b/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..60543c7308aa6ab1b27430786a77719d4cba0ad5 --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/CMakeLists.txt @@ -0,0 +1,8 @@ +add_executable(client_grouped_conv2d_bwd_data grouped_conv2d_bwd_data.cpp) +target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::device_operations) + +add_executable(client_grouped_conv3d_bwd_data grouped_conv3d_bwd_data.cpp) +target_link_libraries(client_grouped_conv3d_bwd_data PRIVATE composable_kernel::device_operations) + +add_executable(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp) +target_link_libraries(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 PRIVATE composable_kernel::device_operations) diff --git a/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp similarity index 99% rename from client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp rename to client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp index 55c789804230ccccf66d68be9244c5c4111451e6..1b2e8abc201c2aed2cd2eebccb68405a25033a43 100644 --- a/client_example/10_grouped_conv2d_bwd_data/grouped_conv2d_bwd_data.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d2f2ff41bcc8f3eaa901594da9957d4cc8ff9dbb --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 2; +static constexpr ck::index_t N = 16; +static constexpr ck::index_t K = 16; +static constexpr ck::index_t C = 16; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 14; +static constexpr ck::index_t Hi = 14; +static constexpr ck::index_t Wi = 14; +static constexpr ck::index_t Do = 14; +static constexpr ck::index_t Ho = 14; +static constexpr ck::index_t Wo = 14; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + PassThrough, + PassThrough, + PassThrough>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X; + std::size_t num_bytes = sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2330228d1d8d9f3fbc41ede8fb58da6c916856cc --- /dev/null +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 2; +static constexpr ck::index_t N = 16; +static constexpr ck::index_t K = 16; +static constexpr ck::index_t C = 16; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 14; +static constexpr ck::index_t Hi = 14; +static constexpr ck::index_t Wi = 14; +static constexpr ck::index_t Do = 14; +static constexpr ck::index_t Ho = 14; +static constexpr ck::index_t Wo = 14; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + PassThrough, + PassThrough, + PassThrough, + ck::bf8_t, + ck::f8_t>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X; + std::size_t num_bytes = sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/11_grouped_conv_bwd_weight/common.hpp b/client_example/11_grouped_conv_bwd_weight/common.hpp index a906263333c8a15947e12fe8702e431d4acb7999..4292cded2071526381db36c31a5e660f7087cbce 100644 --- a/client_example/11_grouped_conv_bwd_weight/common.hpp +++ b/client_example/11_grouped_conv_bwd_weight/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -32,63 +32,49 @@ struct SimpleDeviceMem }; template -std::size_t GetFlops(ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - const std::array& output_spatial_lengths, - const std::array& filter_spatial_lengths) +std::size_t GetFlops(const std::array& output_lengths, + const std::array& filter_lengths) { + constexpr ck::index_t spatial_offset = 3; + const auto C = filter_lengths[2]; // 2 * G * N * K * C * * - return static_cast(2) * G * N * K * C * - std::accumulate(std::begin(output_spatial_lengths), - std::end(output_spatial_lengths), + return static_cast(2) * C * + std::accumulate(std::begin(output_lengths), + std::end(output_lengths), static_cast(1), std::multiplies<>()) * - std::accumulate(std::begin(filter_spatial_lengths), - std::end(filter_spatial_lengths), + std::accumulate(std::begin(filter_lengths) + spatial_offset, + std::end(filter_lengths), static_cast(1), std::multiplies<>()); } template -std::size_t GetInputByte(ck::index_t G, - ck::index_t N, - ck::index_t C, - const std::array& input_spatial_lengths) +std::size_t GetInputByte(const std::array& input_lengths) { // sizeof(InDataType) * (G * N * C * ) + - return sizeof(InDataType) * (G * N * C * - std::accumulate(std::begin(input_spatial_lengths), - std::end(input_spatial_lengths), + return sizeof(InDataType) * (std::accumulate(std::begin(input_lengths), + std::end(input_lengths), static_cast(1), std::multiplies<>())); } template -std::size_t GetWeightByte(ck::index_t G, - ck::index_t K, - ck::index_t C, - const std::array& filter_spatial_lengths) +std::size_t GetWeightByte(const std::array& filter_lengths) { // sizeof(WeiDataType) * (G * K * C * ) + - return sizeof(WeiDataType) * (G * K * C * - std::accumulate(std::begin(filter_spatial_lengths), - std::end(filter_spatial_lengths), + return sizeof(WeiDataType) * (std::accumulate(std::begin(filter_lengths), + std::end(filter_lengths), static_cast(1), std::multiplies<>())); } template -std::size_t GetOutputByte(ck::index_t G, - ck::index_t N, - ck::index_t K, - const std::array& output_spatial_lengths) +std::size_t GetOutputByte(const std::array& output_lengths) { // sizeof(OutDataType) * (G * N * K * ); - return sizeof(OutDataType) * (G * N * K * - std::accumulate(std::begin(output_spatial_lengths), - std::end(output_spatial_lengths), + return sizeof(OutDataType) * (std::accumulate(std::begin(output_lengths), + std::end(output_lengths), static_cast(1), std::multiplies())); } @@ -101,13 +87,12 @@ template bool run_grouped_conv_bwd_weight( - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, + const std::array& input_lengths, + const std::array& input_strides, + const std::array& filter_lengths, + const std::array& weights_strides, + const std::array& output_lengths, + const std::array& output_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -115,9 +100,9 @@ bool run_grouped_conv_bwd_weight( { ck::index_t split_k = 2; - SimpleDeviceMem in(GetInputByte(G, N, C, input_spatial_lengths)); - SimpleDeviceMem wei(GetWeightByte(G, K, C, filter_spatial_lengths)); - SimpleDeviceMem out(GetOutputByte(G, N, K, output_spatial_lengths)); + SimpleDeviceMem in(GetInputByte(input_lengths)); + SimpleDeviceMem wei(GetWeightByte(filter_lengths)); + SimpleDeviceMem out(GetOutputByte(output_lengths)); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + // profile device operation instances std::cout << "Run all instances and do timing" << std::endl; @@ -150,13 +139,12 @@ bool run_grouped_conv_bwd_weight( auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), out.GetDeviceBuffer(), - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -172,12 +160,10 @@ bool run_grouped_conv_bwd_weight( { float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - std::size_t flop = - GetFlops(G, N, K, C, output_spatial_lengths, filter_spatial_lengths); - std::size_t num_bytes = - GetInputByte(G, N, C, input_spatial_lengths) + - GetWeightByte(G, K, C, filter_spatial_lengths) + - GetOutputByte(G, N, K, output_spatial_lengths); + std::size_t flop = GetFlops(output_lengths, filter_lengths); + std::size_t num_bytes = GetInputByte(input_lengths) + + GetWeightByte(filter_lengths) + + GetOutputByte(output_lengths); float tflops = static_cast(flop) / 1.E9 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time; @@ -217,13 +203,12 @@ bool run_grouped_conv_bwd_weight( auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), out.GetDeviceBuffer(), - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp index 788d50ddefb704d9c08c081ae24f40480882a04f..e6d427faf4d75033a931ec2bfeeb72f4cec90e31 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp @@ -22,6 +22,16 @@ static constexpr ck::index_t C = 192; static constexpr ck::index_t X = 3; static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wo = 28; +static constexpr std::array input_lengths{G, N, C, Wi}; +static constexpr std::array filter_lengths{G, K, C, X}; +static constexpr std::array output_lengths{G, N, K, Wo}; +static constexpr std::array input_strides{N * Wi * C, Wi* C, 1, C}; +static constexpr std::array weights_strides{K * X * C, X* C, 1, C}; +static constexpr std::array output_strides{N * Wo * K, Wo* K, 1, K}; +static constexpr std::array conv_filter_strides{1}; +static constexpr std::array conv_filter_dilations{1}; +static constexpr std::array input_left_pads{1}; +static constexpr std::array input_right_pads{1}; int main() { @@ -31,7 +41,16 @@ int main() OutDataType, InLayout, WeiLayout, - OutLayout>(G, N, K, C, {Wi}, {X}, {Wo}, {1}, {1}, {1}, {1}) + OutLayout>(input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp index 1903bd95b67b0834c73047cde0cdae6fe4e7fe82..4201ea61b4ee3c62f0bb8034b6e0510242000942 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp @@ -25,6 +25,19 @@ static constexpr ck::index_t Hi = 28; static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Wo = 28; +static constexpr std::array input_lengths{G, N, C, Hi, Wi}; +static constexpr std::array filter_lengths{G, K, C, Y, X}; +static constexpr std::array output_lengths{G, N, K, Ho, Wo}; +static constexpr std::array input_strides{ + N * Hi * Wi * C, Hi* Wi* C, 1, Wi* C, C}; +static constexpr std::array weights_strides{ + K * Y * X * C, Y* X* C, 1, X* C, C}; +static constexpr std::array output_strides{ + N * Ho * Wo * K, Ho* Wo* K, 1, Wo* K, K}; +static constexpr std::array conv_filter_strides{1, 1}; +static constexpr std::array conv_filter_dilations{1, 1}; +static constexpr std::array input_left_pads{1, 1}; +static constexpr std::array input_right_pads{1, 1}; int main() { @@ -34,8 +47,16 @@ int main() OutDataType, InLayout, WeiLayout, - OutLayout>( - G, N, K, C, {Hi, Wi}, {Y, X}, {Ho, Wo}, {1, 1}, {1, 1}, {1, 1}, {1, 1}) + OutLayout>(input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp index 2f2b5d4e2113c090c0e171a4204f3f9cfde2d276..3ae46bcd5566cd3958e1b4e666119ec55e3528e9 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp @@ -28,6 +28,19 @@ static constexpr ck::index_t Wi = 3; static constexpr ck::index_t Do = 28; static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Wo = 3; +static constexpr std::array input_lengths{G, N, C, Di, Hi, Wi}; +static constexpr std::array filter_lengths{G, K, C, Z, Y, X}; +static constexpr std::array output_lengths{G, N, K, Do, Ho, Wo}; +static constexpr std::array input_strides{ + N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C}; +static constexpr std::array weights_strides{ + K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C}; +static constexpr std::array output_strides{ + N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K}; +static constexpr std::array conv_filter_strides{1, 1, 1}; +static constexpr std::array conv_filter_dilations{1, 1, 1}; +static constexpr std::array input_left_pads{1, 1, 1}; +static constexpr std::array input_right_pads{1, 1, 1}; int main() { @@ -37,17 +50,16 @@ int main() OutDataType, InLayout, WeiLayout, - OutLayout>(G, - N, - K, - C, - {Di, Hi, Wi}, - {Z, Y, X}, - {Do, Ho, Wo}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}) + OutLayout>(input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp index 796311d2318e7d86198fd8ea631680876028c985..2eb869f3923ad0d743d5d0061842683e2fd95f87 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp @@ -28,6 +28,19 @@ static constexpr ck::index_t Wi = 3; static constexpr ck::index_t Do = 28; static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Wo = 3; +static constexpr std::array input_lengths{G, N, C, Di, Hi, Wi}; +static constexpr std::array filter_lengths{G, K, C, Z, Y, X}; +static constexpr std::array output_lengths{G, N, K, Do, Ho, Wo}; +static constexpr std::array input_strides{ + N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C}; +static constexpr std::array weights_strides{ + K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C}; +static constexpr std::array output_strides{ + N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K}; +static constexpr std::array conv_filter_strides{1, 1, 1}; +static constexpr std::array conv_filter_dilations{1, 1, 1}; +static constexpr std::array input_left_pads{1, 1, 1}; +static constexpr std::array input_right_pads{1, 1, 1}; int main() { @@ -37,17 +50,16 @@ int main() OutDataType, InLayout, WeiLayout, - OutLayout>(G, - N, - K, - C, - {Di, Hi, Wi}, - {Z, Y, X}, - {Do, Ho, Wo}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}) + OutLayout>(input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp index de68f46d398958917e49bf14178f66414590ed86..bc4a6fe0bfa9e118bbd6ba32ecc7dd68f3b8b2c3 100644 --- a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp +++ b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp b/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp index 8ef21986a4d9a5f1eb25e21a1073c4cc341da88d..1ed36e0f50f05483a938a9a7b1c433df003b3b39 100644 --- a/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -191,6 +191,12 @@ int main(int argc, char* argv[]) if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); } diff --git a/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp b/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp index 322667a46bacae8d0c681939c3890ef9ff476b0e..f9af011c8480a5caa510a66467edc33e886fe0e2 100644 --- a/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -187,6 +187,12 @@ int main(int argc, char* argv[]) if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); } diff --git a/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp b/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp index 3117d162db71a0a4d02a15decac20e1f0f50d56e..5e6627ce14d113224c7b0acb4fea69cb36c1f369 100644 --- a/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp b/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp index 9cfeee1cfe106e69f83dc0184f3956d6751a2947..d45782d8e0ff37027a204c6820447286581f1138 100644 --- a/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp +++ b/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp index 28524a9eee9b87db4223484497f8e2181ed366ea..c74d7c6bd8cf9cec54168792a8718768827d7b33 100644 --- a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp +++ b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/15_reduce/reduce_nhwc_c.cpp b/client_example/15_reduce/reduce_nhwc_c.cpp index 2275158bcb26d36d871c9e7086b45d4539584785..b45b72f0de0199daa88e9e42be320e2669398dfe 100644 --- a/client_example/15_reduce/reduce_nhwc_c.cpp +++ b/client_example/15_reduce/reduce_nhwc_c.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index e2580a370ca480e0b9ec101e5715ac9a631a72f6..249c2c030fdc2940f8cd2b47ad70a166b635b6ae 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -1,5 +1,15 @@ -add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp) -add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) +if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp) + target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations) -target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations) -target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations) +endif() + +if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp) + target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_operations) +endif() + +if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) + target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations) +endif() diff --git a/client_example/16_convnd_fwd/common.hpp b/client_example/16_convnd_fwd/common.hpp index a6bb5aa65be098716b504ae9d0543ed9ae129dc0..50c03c49d8c55039f63305aa24d8718972df55dd 100644 --- a/client_example/16_convnd_fwd/common.hpp +++ b/client_example/16_convnd_fwd/common.hpp @@ -94,7 +94,8 @@ template + ck::index_t NumNonSpatialDim = 3, + typename ComputeType = InDataType> bool run_grouped_conv_fwd(std::array in_lengths, std::array wei_lengths, std::array out_lengths) @@ -141,14 +142,10 @@ bool run_grouped_conv_fwd(std::array; + PassThrough, + ComputeType>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp index 10f914bbee3fe22cd0c3f3cfc289382f2f58ca39..d4455df628afcb6ec721b5f884b79e4572a5037c 100644 --- a/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp @@ -11,7 +11,7 @@ using WeiDataType = ck::half_t; using OutDataType = ck::half_t; using InLayout = ck::tensor_layout::convolution::NDHWGC; -using WeiLayout = ck::tensor_layout::convolution::KZYXGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; using OutLayout = ck::tensor_layout::convolution::NDHWGK; static constexpr ck::index_t NumDimSpatial = 3; @@ -38,7 +38,7 @@ int main() InLayout, WeiLayout, OutLayout>( - {N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K}) + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1651ec2f39356dd0edec755385c3716342502077 --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp index 43c98f1e9b8ae8b89893d5e430d4f4ab89678803..7e8c98b6037418001571707a82aa14987059fb4f 100644 --- a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp @@ -11,7 +11,7 @@ using WeiDataType = float; using OutDataType = float; using InLayout = ck::tensor_layout::convolution::NDHWGC; -using WeiLayout = ck::tensor_layout::convolution::KZYXGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; using OutLayout = ck::tensor_layout::convolution::NDHWGK; static constexpr ck::index_t NumDimSpatial = 3; @@ -38,7 +38,7 @@ int main() InLayout, WeiLayout, OutLayout>( - {N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K}) + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp b/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp index 223ed29be9ac66e8805d8d5b3e696bbac4f5c922..7ba3224fc3244ab3bd472b98864efa42ae132fde 100644 --- a/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp +++ b/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/18_groupnorm/groupnorm_swish.cpp b/client_example/18_groupnorm/groupnorm_swish.cpp index a79630c2371b169533e932bd86f422b711ffdfca..abe7492c65f4aa046abd4e5484056f099abc2173 100644 --- a/client_example/18_groupnorm/groupnorm_swish.cpp +++ b/client_example/18_groupnorm/groupnorm_swish.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -12,12 +12,14 @@ #include "ck/library/tensor_operation_instance/gpu/normalization_swish.hpp" -using XDataType = ck::half_t; -using GammaDataType = float; -using BetaDataType = float; -using YDataType = ck::half_t; -using ComputeDataType = float; -using Swish = ck::tensor_operation::element_wise::Swish; +using XDataType = ck::half_t; +using GammaDataType = float; +using BetaDataType = float; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using Swish = ck::tensor_operation::element_wise::Swish; + +#define SAVE_MEAN_INV_STD constexpr int Rank = 5; constexpr int NumReduceDim = 3; @@ -49,19 +51,24 @@ int main(int argc, char* argv[]) std::size_t xy_size = N * H * W * G * C; std::size_t gamma_beta_size = G * C; - std::vector xy_strides = {H * W * G * C, W * G * C, G * C, C, 1}; - std::vector gamma_beta_strides = {0, 0, 0, C, 1}; + std::vector xy_strides = {H * W * G * C, W * G * C, G * C, C, 1}; + std::vector gamma_beta_strides = {0, 0, 0, C, 1}; + std::vector save_mean_inv_std_strides = {G, 1}; SimpleDeviceMem x_device_buf(sizeof(XDataType) * xy_size); SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_beta_size); SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * gamma_beta_size); SimpleDeviceMem y_device_buf(sizeof(YDataType) * xy_size); +#ifdef SAVE_MEAN_INV_STD + SimpleDeviceMem save_mean_device_buf(sizeof(SaveMeanInvStdDataType) * N * G); + SimpleDeviceMem save_inv_std_device_buf(sizeof(SaveMeanInvStdDataType) * N * G); +#endif using DeviceOp = ck::tensor_operation::device::DeviceNormalization; @@ -72,6 +79,37 @@ int main(int argc, char* argv[]) std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + const auto& generic_op_ptr = op_ptrs[0]; + + auto generic_argument_ptr = + generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths + xy_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + xy_strides, // yStrides + save_mean_inv_std_strides, // save_mean Strides + save_mean_inv_std_strides, // save_inv_std Strides + {1, 2, 4}, // reduceDims + 1e-6, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + Swish{}); + + if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get())) + { + throw std::runtime_error( + "The generic kernel instance should be able to support any input shapes"); + }; + std::string best_op_name; bool found = false; int best_op_id = -1; @@ -83,21 +121,29 @@ int main(int argc, char* argv[]) for(int i = 0; i < op_ptrs.size(); ++i) { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths - xy_strides, // xStrides - gamma_beta_strides, // gammaStrides - gamma_beta_strides, // betaStrides - xy_strides, // yStrides - {1, 2, 4}, // reduceDims - 1e-6, - x_device_buf.GetDeviceBuffer(), - gamma_device_buf.GetDeviceBuffer(), - beta_device_buf.GetDeviceBuffer(), - y_device_buf.GetDeviceBuffer(), - nullptr, - nullptr, - Swish{}); + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths + xy_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + xy_strides, // yStrides + save_mean_inv_std_strides, // save_mean Strides + save_mean_inv_std_strides, // save_inv_std Strides + {1, 2, 4}, // reduceDims + 1e-6, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + Swish{}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); @@ -105,12 +151,20 @@ int main(int argc, char* argv[]) if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); std::size_t num_byte = sizeof(XDataType) * xy_size + sizeof(GammaDataType) * gamma_beta_size + sizeof(BetaDataType) * gamma_beta_size + sizeof(YDataType) * xy_size; +#ifdef SAVE_MEAN_INV_STD + num_byte += sizeof(SaveMeanInvStdDataType) * N * G * 2; +#endif + float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " @@ -131,34 +185,47 @@ int main(int argc, char* argv[]) } } - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " - << best_op_name << std::endl; - // run the best intance + if(found) { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths - xy_strides, // xStrides - gamma_beta_strides, // gammaStrides - gamma_beta_strides, // betaStrides - xy_strides, // yStrides - {1, 2, 4}, // reduceDims - 1e-6, - x_device_buf.GetDeviceBuffer(), - gamma_device_buf.GetDeviceBuffer(), - beta_device_buf.GetDeviceBuffer(), - y_device_buf.GetDeviceBuffer(), - nullptr, - nullptr, - Swish{}); + auto argument_ptr = + op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths + xy_strides, // xStrides + gamma_beta_strides, // gammaStrides + gamma_beta_strides, // betaStrides + xy_strides, // yStrides + save_mean_inv_std_strides, // save_mean Strides + save_mean_inv_std_strides, // save_inv_std Strides + {1, 2, 4}, // reduceDims + 1e-6, + x_device_buf.GetDeviceBuffer(), + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + y_device_buf.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_device_buf.GetDeviceBuffer(), + save_inv_std_device_buf.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + Swish{}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); if(op_ptr->IsSupportedArgument(argument_ptr.get())) { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); } diff --git a/client_example/19_pool/CMakeLists.txt b/client_example/19_pool/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d4e2e6d4dc2d18c551e6252af631404b396943d5 --- /dev/null +++ b/client_example/19_pool/CMakeLists.txt @@ -0,0 +1,11 @@ +add_executable(client_max_pool2d_fwd max_pool2d_fwd.cpp) +target_link_libraries(client_max_pool2d_fwd PRIVATE composable_kernel::device_operations) + +add_executable(client_max_pool2d_bwd max_pool2d_bwd.cpp) +target_link_libraries(client_max_pool2d_bwd PRIVATE composable_kernel::device_operations) + +add_executable(client_avg_pool3d_fwd avg_pool3d_fwd.cpp) +target_link_libraries(client_avg_pool3d_fwd PRIVATE composable_kernel::device_operations) + +add_executable(client_avg_pool3d_bwd avg_pool3d_bwd.cpp) +target_link_libraries(client_avg_pool3d_bwd PRIVATE composable_kernel::device_operations) diff --git a/client_example/19_pool/avg_pool3d_bwd.cpp b/client_example/19_pool/avg_pool3d_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..686d1da3ad61a5cb9ca4c8bdec034591cd354c55 --- /dev/null +++ b/client_example/19_pool/avg_pool3d_bwd.cpp @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/avg_pool3d_bwd.hpp" + +using DOutDataType = ck::half_t; +using DInDataType = ck::half_t; + +using DOutLayout = ck::tensor_layout::convolution::NDHWC; +using DInLayout = ck::tensor_layout::convolution::NDHWC; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{}, mMemSize_(mem_size) + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + void SetZero() const { (void)hipMemset(p_mem_, 0, mMemSize_); } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; + std::size_t mMemSize_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Z = 2; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Di = 30; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_d = 2; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_d = 1; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_d = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_d = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + const ck::index_t Zs = (Z - 1) * window_dilation_d + 1; + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Zs) / window_stride_d + 1; + ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; + + // Pool API only support the order of NCDHW + std::vector in_length = {N, C, Di, Hi, Wi}; + std::vector out_length = {N, C, Do, Ho, Wo}; + std::vector window_spatial_lengths = {Z, Y, X}; + std::vector window_strides = {window_stride_d, window_stride_h, window_stride_w}; + std::vector window_dilations{ + window_dilation_d, window_dilation_h, window_dilation_w}; + std::vector input_left_pads = {in_left_pad_d, in_left_pad_h, in_left_pad_w}; + std::vector input_right_pads = {in_right_pad_d, in_right_pad_h, in_right_pad_w}; + + std::size_t in_tensor_size = N * C * Di * Hi * Wi; + std::size_t out_tensor_size = N * C * Do * Ho * Wo; + + // tensor layout = NDHWC + std::vector in_tensor_stride = {Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C}; + std::vector out_tensor_stride = {Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}; + + SimpleDeviceMem dout_device_buf(sizeof(DOutDataType) * out_tensor_size); + SimpleDeviceMem din_device_buf(sizeof(DInDataType) * in_tensor_size); + + using DeviceOp = ck::tensor_operation::device:: + DeviceAvgPoolBwd<3, DOutDataType, DInDataType, DOutLayout, DInLayout>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_length, + in_length, + out_tensor_stride, + in_tensor_stride, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + din_device_buf.SetZero(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = + in_tensor_size * sizeof(DInDataType) + out_tensor_size * sizeof(DOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_length, + in_length, + out_tensor_stride, + in_tensor_stride, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + din_device_buf.SetZero(); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/19_pool/avg_pool3d_fwd.cpp b/client_example/19_pool/avg_pool3d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..db8e0569d717f08e9b64e4504d40dd23aab13129 --- /dev/null +++ b/client_example/19_pool/avg_pool3d_fwd.cpp @@ -0,0 +1,214 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using IndexDataType = int32_t; + +using InLayout = ck::tensor_layout::convolution::NDHWC; +using OutLayout = ck::tensor_layout::convolution::NDHWC; + +constexpr ck::index_t InOutRank = 5; +constexpr ck::index_t WindowRank = 3; +#if 0 +constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +constexpr bool OutputIndex = false; +#else +constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +constexpr bool OutputIndex = false; +#endif + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Z = 2; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Di = 30; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_d = 2; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_d = 1; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_d = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_d = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + const ck::index_t Zs = (Z - 1) * window_dilation_d + 1; + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Zs) / window_stride_d + 1; + ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; + + // Pool API only support the order of NCDHW + std::vector in_length = {N, C, Di, Hi, Wi}; + std::vector out_length = {N, C, Do, Ho, Wo}; + std::vector window_spatial_lengths = {Z, Y, X}; + std::vector window_strides = {window_stride_d, window_stride_h, window_stride_w}; + std::vector window_dilations{ + window_dilation_d, window_dilation_h, window_dilation_w}; + std::vector input_left_pads = {in_left_pad_d, in_left_pad_h, in_left_pad_w}; + std::vector input_right_pads = {in_right_pad_d, in_right_pad_h, in_right_pad_w}; + + std::size_t in_tensor_size = N * C * Di * Hi * Wi; + std::size_t out_tensor_size = N * C * Do * Ho * Wo; + + // tensor layout = NDHWC + std::vector in_tensor_stride = {Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C}; + std::vector out_tensor_stride = {Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C}; + + SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size); + SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size); + SimpleDeviceMem out_indices_device_buf(sizeof(IndexDataType) * out_tensor_size); + + using DeviceOp = ck::tensor_operation::device::DevicePoolFwd; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {2, 3, 4}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = + in_tensor_size * sizeof(InDataType) + out_tensor_size * sizeof(OutDataType); + + if constexpr(OutputIndex) + num_bytes += out_tensor_size * sizeof(IndexDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {2, 3, 4}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/19_pool/max_pool2d_bwd.cpp b/client_example/19_pool/max_pool2d_bwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..53ece7425f99bd6eadc6de1a6f041a12fa0951d0 --- /dev/null +++ b/client_example/19_pool/max_pool2d_bwd.cpp @@ -0,0 +1,280 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp" +#include "ck/library/tensor_operation_instance/gpu/max_pool_bwd.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using DOutDataType = ck::half_t; +using DInDataType = ck::half_t; +using IndexDataType = int32_t; + +// We use pool3d to implement pool2d in this example +using InLayout = ck::tensor_layout::convolution::NDHWC; +using OutLayout = ck::tensor_layout::convolution::NDHWC; + +constexpr ck::index_t InOutRank = 5; +constexpr ck::index_t WindowRank = 3; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +void TransformPool2dparamToPool3d(std::vector& input_lengths, + std::vector& window_lengths, + std::vector& output_lengths, + std::vector& input_stride, + std::vector& output_stride, + std::vector& indices_stride, + std::vector& window_strides, + std::vector& window_dilations, + std::vector& input_left_pads, + std::vector& input_right_pads, + std::vector& pooling_dims) +{ + // NCHW to NCDHW + input_lengths.insert(input_lengths.begin() + 2, 1); + output_lengths.insert(output_lengths.begin() + 2, 1); + input_stride.insert(input_stride.begin() + 2, 0); + output_stride.insert(output_stride.begin() + 2, 0); + indices_stride.insert(indices_stride.begin() + 2, 0); + + // YX to ZYX + window_lengths.insert(window_lengths.begin(), 1); + window_strides.insert(window_strides.begin(), 0); + window_dilations.insert(window_dilations.begin(), 0); + input_left_pads.insert(input_left_pads.begin(), 0); + input_right_pads.insert(input_right_pads.begin(), 0); + + pooling_dims = {2, 3, 4}; +} + +int main(int argc, char* argv[]) +{ + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; + + // Pool API only support the order of NCHW + std::vector in_length = {N, C, Hi, Wi}; + std::vector out_length = {N, C, Ho, Wo}; + std::vector window_spatial_lengths = {Y, X}; + std::vector window_strides = {window_stride_h, window_stride_w}; + std::vector window_dilations = {window_dilation_h, window_dilation_w}; + std::vector input_left_pads = {in_left_pad_h, in_left_pad_w}; + std::vector input_right_pads = {in_right_pad_h, in_right_pad_w}; + std::vector pooling_dims = {2, 3}; + + std::size_t in_tensor_size = N * C * Hi * Wi; + std::size_t out_tensor_size = N * C * Ho * Wo; + + // tensor layout = NHWC + std::vector in_tensor_stride = {C * Hi * Wi, 1, Wi * C, C}; + std::vector out_tensor_stride = {C * Ho * Wo, 1, Wo * C, C}; + + TransformPool2dparamToPool3d(in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + pooling_dims); + + SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size); + SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size); + SimpleDeviceMem indices_device_buf(sizeof(IndexDataType) * out_tensor_size); + SimpleDeviceMem dout_device_buf(sizeof(DOutDataType) * out_tensor_size); + SimpleDeviceMem din_device_buf(sizeof(DInDataType) * in_tensor_size); + + // Generate index data from max pool forward + { + using MaxPoolFwdDeviceOp = + ck::tensor_operation::device::DevicePoolFwd; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + MaxPoolFwdDeviceOp>::GetInstances(); + + auto& op_ptr = op_ptrs[0]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + pooling_dims); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + } + + // Run MaxPool bwd + using MaxPoolBwdDeviceOp = + ck::tensor_operation::device::DeviceMaxPoolBwd; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + MaxPoolBwdDeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_tensor_size, + in_tensor_size, + window_spatial_lengths, + window_strides, + window_dilations); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = in_tensor_size * sizeof(DInDataType) + + out_tensor_size * sizeof(IndexDataType) + + out_tensor_size * sizeof(DOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << "GB / s," + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + out_tensor_size, + in_tensor_size, + window_spatial_lengths, + window_strides, + window_dilations); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + + SimpleDeviceMem workspace(workspace_sz); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer()); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/19_pool/max_pool2d_fwd.cpp b/client_example/19_pool/max_pool2d_fwd.cpp new file mode 100644 index 0000000000000000000000000000000000000000..84b818a60fdec6a62ffe027285bdbec854721d5a --- /dev/null +++ b/client_example/19_pool/max_pool2d_fwd.cpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using IndexDataType = int32_t; + +// We use pool3d to implement pool2d in this example +using InLayout = ck::tensor_layout::convolution::NDHWC; +using OutLayout = ck::tensor_layout::convolution::NDHWC; + +constexpr ck::index_t InOutRank = 5; +constexpr ck::index_t WindowRank = 3; +#if 1 +constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +constexpr bool OutputIndex = true; +#else +constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +constexpr bool OutputIndex = false; +#endif + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +void TransformPool2dparamToPool3d(std::vector& input_lengths, + std::vector& window_lengths, + std::vector& output_lengths, + std::vector& input_stride, + std::vector& output_stride, + std::vector& indices_stride, + std::vector& window_strides, + std::vector& window_dilations, + std::vector& input_left_pads, + std::vector& input_right_pads, + std::vector& pooling_dims) +{ + // NCHW to NCDHW + input_lengths.insert(input_lengths.begin() + 2, 1); + output_lengths.insert(output_lengths.begin() + 2, 1); + input_stride.insert(input_stride.begin() + 2, 0); + output_stride.insert(output_stride.begin() + 2, 0); + indices_stride.insert(indices_stride.begin() + 2, 0); + + // YX to ZYX + window_lengths.insert(window_lengths.begin(), 1); + window_strides.insert(window_strides.begin(), 0); + window_dilations.insert(window_dilations.begin(), 0); + input_left_pads.insert(input_left_pads.begin(), 0); + input_right_pads.insert(input_right_pads.begin(), 0); + + pooling_dims = {2, 3, 4}; +} + +int main(int argc, char* argv[]) +{ + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; + + // Pool API only support the order of NCHW + std::vector in_length = {N, C, Hi, Wi}; + std::vector out_length = {N, C, Ho, Wo}; + std::vector window_spatial_lengths = {Y, X}; + std::vector window_strides = {window_stride_h, window_stride_w}; + std::vector window_dilations = {window_dilation_h, window_dilation_w}; + std::vector input_left_pads = {in_left_pad_h, in_left_pad_w}; + std::vector input_right_pads = {in_right_pad_h, in_right_pad_w}; + std::vector pooling_dims = {2, 3}; + + std::size_t in_tensor_size = N * C * Hi * Wi; + std::size_t out_tensor_size = N * C * Ho * Wo; + + // tensor layout = NHWC + std::vector in_tensor_stride = {C * Hi * Wi, 1, Wi * C, C}; + std::vector out_tensor_stride = {C * Ho * Wo, 1, Wo * C, C}; + + TransformPool2dparamToPool3d(in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + pooling_dims); + + SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size); + SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size); + SimpleDeviceMem out_indices_device_buf(sizeof(IndexDataType) * out_tensor_size); + + using DeviceOp = ck::tensor_operation::device::DevicePoolFwd; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + pooling_dims); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = + in_tensor_size * sizeof(InDataType) + out_tensor_size * sizeof(OutDataType); + + if constexpr(OutputIndex) + num_bytes += out_tensor_size * sizeof(IndexDataType); + + float gb_per_sec = num_bytes / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(found) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + in_length, + window_spatial_lengths, + out_length, + in_tensor_stride, + out_tensor_stride, + out_tensor_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + pooling_dims); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/20_splitk_gemm/CMakeLists.txt b/client_example/20_splitk_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5571ed1d7008fd3894b746bb8472487e19e39969 --- /dev/null +++ b/client_example/20_splitk_gemm/CMakeLists.txt @@ -0,0 +1,4 @@ +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) + target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations) +endif() diff --git a/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp b/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..94a57cd029a30d0fc73bae8cb8e43c1d475c95bc --- /dev/null +++ b/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp @@ -0,0 +1,225 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +using ADataType = F8; +using BDataType = F16; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + ck::index_t KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 8) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideC = std::stoi(argv[6]); + + KBatch = std::stoi(argv[7]); + } + else + { + printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideC, KBatch\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + CDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + KBatch); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + KBatch); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/21_grouped_gemm_bias/CMakeLists.txt b/client_example/21_grouped_gemm_bias/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a2abd15731fcf97af99646d5dced7aa944a974aa --- /dev/null +++ b/client_example/21_grouped_gemm_bias/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) +target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations) diff --git a/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp b/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c758720e10f6cd222ee68ecd5b54069984279c31 --- /dev/null +++ b/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using D0DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Row; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + const int group_count = 16; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, d0_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + d0_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + d0_dev_bufs.emplace_back(sizeof(D0DataType) * + f_matrix_space_size(Ms[i], Ns[i], 0, D0Layout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back( + {a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + std::array{d0_dev_bufs[i].GetDeviceBuffer()}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + std::array{0}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 2); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_grouped_gemm/CMakeLists.txt b/client_example/22_grouped_gemm/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..05b9e1e29da9df4d78275b7395e5cb58c71759a2 --- /dev/null +++ b/client_example/22_grouped_gemm/CMakeLists.txt @@ -0,0 +1,8 @@ +add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp) +target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_operations) + +add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp) +target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_operations) + +add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp) +target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_operations) diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b16fe903875c10f78c3b7e2425185176db49575c --- /dev/null +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp @@ -0,0 +1,236 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + const int group_count = 16; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + {}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 32); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..045fe47c4fdc80fa13721a1d219d35e26cc25dd9 --- /dev/null +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F8; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + const int group_count = 16; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + {}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 16); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8f82140f3fc3963eb1a5ee3b3aad67db69f96d92 --- /dev/null +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +using I8 = int8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = I8; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + const int group_count = 16; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + {}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 32); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_im2col_col2im/CMakeLists.txt b/client_example/22_im2col_col2im/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..47ac42fe8704d62e0820ce46e62ee333cfe62c27 --- /dev/null +++ b/client_example/22_im2col_col2im/CMakeLists.txt @@ -0,0 +1,5 @@ +add_executable(client_image_to_column image_to_column.cpp) +target_link_libraries(client_image_to_column PRIVATE composable_kernel::device_operations) + +add_executable(client_column_to_image column_to_image.cpp) +target_link_libraries(client_column_to_image PRIVATE composable_kernel::device_operations) diff --git a/client_example/22_im2col_col2im/column_to_image.cpp b/client_example/22_im2col_col2im/column_to_image.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ebe63198fbc7d943763289157dd84cd26cece3b --- /dev/null +++ b/client_example/22_im2col_col2im/column_to_image.cpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange.hpp" +#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; + +using ImageLayout = ck::tensor_layout::convolution::NHWGC; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 2; +static constexpr ck::index_t N = 32; // batch size +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 28; // input H +static constexpr ck::index_t Wi = 28; // input W +static constexpr ck::index_t Ho = 28; // output H +static constexpr ck::index_t Wo = 28; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + + std::array in_spatial_lengths{Hi, Wi}; + std::array wei_spatial_lengths{Y, X}; + std::array out_spatial_lengths{Ho, Wo}; + + // We have NHWGC in memory space + // However, CK's API only accepts lengths and strides with order of GNCHW. + // Hence, we need to adjust the order of strides. + std::array image_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; + std::array gemm_strides{Y * X * C, G * Y * X * C, 1}; + + std::array filter_strides{1, 1}; + std::array filter_dilations{1, 1}; + std::array input_left_pads{1, 1}; + std::array input_right_pads{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Ho * Wo * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Hi * Wi * G * C); + + using namespace ck::conv_tensor_rearrange_op; + + using DeviceOp = ck::tensor_operation::device::DeviceConvTensorRearrange; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + G, + N, + C, + in_spatial_lengths, + out_spatial_lengths, + wei_spatial_lengths, + image_strides, + gemm_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(OutDataType) * G * N * Ho * Wo * Y * X * C; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(avg_time < best_avg_time) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_gb_per_sec + << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + G, + N, + C, + in_spatial_lengths, + out_spatial_lengths, + wei_spatial_lengths, + image_strides, + gemm_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/22_im2col_col2im/image_to_column.cpp b/client_example/22_im2col_col2im/image_to_column.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8eafbdc5bbfec61b52ed3ed6dc5d7045d6d285e9 --- /dev/null +++ b/client_example/22_im2col_col2im/image_to_column.cpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange.hpp" +#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; + +using ImageLayout = ck::tensor_layout::convolution::NHWGC; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 2; +static constexpr ck::index_t N = 32; // batch size +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 28; // input H +static constexpr ck::index_t Wi = 28; // input W +static constexpr ck::index_t Ho = 28; // output H +static constexpr ck::index_t Wo = 28; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + + std::array in_spatial_lengths{Hi, Wi}; + std::array wei_spatial_lengths{Y, X}; + std::array out_spatial_lengths{Ho, Wo}; + + // We have NHWGC in memory space + // However, CK's API only accepts lengths and strides with order of GNCHW. + // Hence, we need to adjust the order of strides. + std::array image_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; + std::array gemm_strides{Y * X * C, G * Y * X * C, 1}; + + std::array filter_strides{1, 1}; + std::array filter_dilations{1, 1}; + std::array input_left_pads{1, 1}; + std::array input_right_pads{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * Y * X * C); + + using namespace ck::conv_tensor_rearrange_op; + + using DeviceOp = ck::tensor_operation::device::DeviceConvTensorRearrange; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + G, + N, + C, + in_spatial_lengths, + out_spatial_lengths, + wei_spatial_lengths, + image_strides, + gemm_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(OutDataType) * G * N * Ho * Wo * Y * X * C; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(avg_time < best_avg_time) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_gb_per_sec + << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + out.GetDeviceBuffer(), + G, + N, + C, + in_spatial_lengths, + out_spatial_lengths, + wei_spatial_lengths, + image_strides, + gemm_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } +} diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 14c066e4a21ca145eb5ddff3e28f6bd3a45b2449..eb793b3cbd6e693fd8cf94ef3489c6ac3eb915c7 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -2,7 +2,53 @@ cmake_minimum_required(VERSION 3.15) project(ck_app) add_compile_options(-std=c++17) -find_package(composable_kernel 1.0.0 COMPONENTS device_operations) +if (DTYPES) + add_definitions(-DDTYPES) + if (DTYPES MATCHES "int8") + add_definitions(-DCK_ENABLE_INT8) + if(NOT DEFINED ${CK_ENABLE_INT8}) + set(CK_ENABLE_INT8 "ON") + endif() + endif() + if (DTYPES MATCHES "fp8") + add_definitions(-DCK_ENABLE_FP8) + if(NOT DEFINED ${CK_ENABLE_FP8}) + set(CK_ENABLE_FP8 "ON") + endif() + endif() + if (DTYPES MATCHES "fp16") + add_definitions(-DCK_ENABLE_FP16) + if(NOT DEFINED ${CK_ENABLE_FP16}) + set(CK_ENABLE_FP16 "ON") + endif() + endif() + if (DTYPES MATCHES "fp32") + add_definitions(-DCK_ENABLE_FP32) + if(NOT DEFINED ${CK_ENABLE_FP32}) + set(CK_ENABLE_FP32 "ON") + endif() + endif() + if (DTYPES MATCHES "fp64") + add_definitions(-DCK_ENABLE_FP64) + if(NOT DEFINED ${CK_ENABLE_FP64}) + set(CK_ENABLE_FP64 "ON") + endif() + endif() + if (DTYPES MATCHES "bf16") + add_definitions(-DCK_ENABLE_BF16) + if(NOT DEFINED ${CK_ENABLE_BF16}) + set(CK_ENABLE_BF16 "ON") + endif() + endif() + message("DTYPES macro set to ${DTYPES}") +else() + add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) + if(NOT DEFINED ${CK_ENABLE_ALL_DTYPES}) + set(CK_ENABLE_ALL_DTYPES "ON") + endif() +endif() + +find_package(composable_kernel COMPONENTS device_operations) find_package(hip REQUIRED PATHS /opt/rocm) message(STATUS "Build with HIP ${hip_VERSION}") diff --git a/cmake/DoxygenDoc.cmake b/cmake/DoxygenDoc.cmake index 2e3669fcdf48fde87c23a84c2493ce26aecda7af..c91308b5bb6db3cac55628c2bf5123f261cc838d 100644 --- a/cmake/DoxygenDoc.cmake +++ b/cmake/DoxygenDoc.cmake @@ -309,6 +309,8 @@ XML_OUTPUT XML_PROGRAMLISTING ) +set(WARN_AS_ERROR YES) + set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file") function(add_doxygen_doc) diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 87bcb08e83057b4f162870706a420a2d88d468bf..87cb8cdf1180e81c7626a0083d144b3a9c38ba48 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -67,8 +67,10 @@ else() -Wunused -Wno-reserved-identifier -Werror + -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt + -Wno-unused-template ) if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang") list(APPEND CMAKE_COMPILER_WARNINGS @@ -92,6 +94,7 @@ else() -Wno-unused-command-line-argument -Wno-weak-vtables -Wno-covered-switch-default + -Wno-unsafe-buffer-usage ) else() if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX") diff --git a/docs/.sphinx/_toc.yml.in b/docs/.sphinx/_toc.yml.in deleted file mode 100644 index ff21248873756e3b9b5572516131df3ddbbc8f33..0000000000000000000000000000000000000000 --- a/docs/.sphinx/_toc.yml.in +++ /dev/null @@ -1 +0,0 @@ -root: index diff --git a/docs/.sphinx/requirements.in b/docs/.sphinx/requirements.in deleted file mode 100644 index 1905de6e6ca082affed0a2516349cfa89ded7812..0000000000000000000000000000000000000000 --- a/docs/.sphinx/requirements.in +++ /dev/null @@ -1,2 +0,0 @@ -rocm-docs-core==0.2.0 -sphinxcontrib-bibtex==2.5.0 diff --git a/docs/.sphinx/requirements.txt b/docs/.sphinx/requirements.txt deleted file mode 100644 index d1698b2855d1ef22c6c80d0cd97f7d3b3c636ac8..0000000000000000000000000000000000000000 --- a/docs/.sphinx/requirements.txt +++ /dev/null @@ -1,283 +0,0 @@ -# -# This file is autogenerated by pip-compile with Python 3.10 -# by the following command: -# -# pip-compile .sphinx/requirements.in -# -accessible-pygments==0.0.3 - # via pydata-sphinx-theme -alabaster==0.7.13 - # via sphinx -asttokens==2.2.1 - # via stack-data -attrs==22.2.0 - # via - # jsonschema - # jupyter-cache -babel==2.12.1 - # via - # pydata-sphinx-theme - # sphinx -backcall==0.2.0 - # via ipython -beautifulsoup4==4.11.2 - # via pydata-sphinx-theme -breathe==4.34.0 - # via rocm-docs-core -certifi==2022.12.7 - # via requests -cffi==1.15.1 - # via pynacl -charset-normalizer==3.1.0 - # via requests -click==8.1.3 - # via - # jupyter-cache - # sphinx-external-toc -comm==0.1.2 - # via ipykernel -debugpy==1.6.6 - # via ipykernel -decorator==5.1.1 - # via ipython -deprecated==1.2.13 - # via pygithub -docutils==0.16 - # via - # breathe - # myst-parser - # pybtex-docutils - # pydata-sphinx-theme - # rocm-docs-core - # sphinx - # sphinxcontrib-bibtex -executing==1.2.0 - # via stack-data -fastjsonschema==2.16.3 - # via nbformat -gitdb==4.0.10 - # via gitpython -gitpython==3.1.31 - # via rocm-docs-core -greenlet==2.0.2 - # via sqlalchemy -idna==3.4 - # via requests -imagesize==1.4.1 - # via sphinx -importlib-metadata==6.0.0 - # via - # jupyter-cache - # myst-nb -ipykernel==6.21.3 - # via myst-nb -ipython==8.11.0 - # via - # ipykernel - # myst-nb -jedi==0.18.2 - # via ipython -jinja2==3.1.2 - # via - # myst-parser - # sphinx -jsonschema==4.17.3 - # via nbformat -jupyter-cache==0.5.0 - # via myst-nb -jupyter-client==8.0.3 - # via - # ipykernel - # nbclient -jupyter-core==5.3.0 - # via - # ipykernel - # jupyter-client - # nbformat -latexcodec==2.0.1 - # via pybtex -linkify-it-py==1.0.3 - # via myst-parser -markdown-it-py==2.2.0 - # via - # mdit-py-plugins - # myst-parser -markupsafe==2.1.2 - # via jinja2 -matplotlib-inline==0.1.6 - # via - # ipykernel - # ipython -mdit-py-plugins==0.3.5 - # via myst-parser -mdurl==0.1.2 - # via markdown-it-py -myst-nb==0.17.1 - # via rocm-docs-core -myst-parser[linkify]==0.18.1 - # via - # myst-nb - # rocm-docs-core -nbclient==0.5.13 - # via - # jupyter-cache - # myst-nb -nbformat==5.7.3 - # via - # jupyter-cache - # myst-nb - # nbclient -nest-asyncio==1.5.6 - # via - # ipykernel - # nbclient -packaging==23.0 - # via - # ipykernel - # pydata-sphinx-theme - # sphinx -parso==0.8.3 - # via jedi -pexpect==4.8.0 - # via ipython -pickleshare==0.7.5 - # via ipython -platformdirs==3.1.1 - # via jupyter-core -prompt-toolkit==3.0.38 - # via ipython -psutil==5.9.4 - # via ipykernel -ptyprocess==0.7.0 - # via pexpect -pure-eval==0.2.2 - # via stack-data -pybtex==0.24.0 - # via - # pybtex-docutils - # sphinxcontrib-bibtex -pybtex-docutils==1.0.2 - # via sphinxcontrib-bibtex -pycparser==2.21 - # via cffi -pydata-sphinx-theme==0.13.1 - # via sphinx-book-theme -pygithub==1.57 - # via rocm-docs-core -pygments==2.14.0 - # via - # accessible-pygments - # ipython - # pydata-sphinx-theme - # sphinx -pyjwt==2.6.0 - # via pygithub -pynacl==1.5.0 - # via pygithub -pyrsistent==0.19.3 - # via jsonschema -python-dateutil==2.8.2 - # via jupyter-client -pyyaml==6.0 - # via - # jupyter-cache - # myst-nb - # myst-parser - # pybtex - # sphinx-external-toc -pyzmq==25.0.1 - # via - # ipykernel - # jupyter-client -requests==2.28.2 - # via - # pygithub - # sphinx -rocm-docs-core==0.2.0 - # via -r .sphinx/requirements.in -six==1.16.0 - # via - # asttokens - # latexcodec - # pybtex - # python-dateutil -smmap==5.0.0 - # via gitdb -snowballstemmer==2.2.0 - # via sphinx -soupsieve==2.4 - # via beautifulsoup4 -sphinx==4.3.1 - # via - # breathe - # myst-nb - # myst-parser - # pydata-sphinx-theme - # rocm-docs-core - # sphinx-book-theme - # sphinx-copybutton - # sphinx-design - # sphinx-external-toc - # sphinx-notfound-page - # sphinxcontrib-bibtex -sphinx-book-theme==1.0.0rc2 - # via rocm-docs-core -sphinx-copybutton==0.5.1 - # via rocm-docs-core -sphinx-design==0.3.0 - # via rocm-docs-core -sphinx-external-toc==0.3.1 - # via rocm-docs-core -sphinx-notfound-page==0.8.3 - # via rocm-docs-core -sphinxcontrib-applehelp==1.0.4 - # via sphinx -sphinxcontrib-bibtex==2.5.0 - # via -r .sphinx/requirements.in -sphinxcontrib-devhelp==1.0.2 - # via sphinx -sphinxcontrib-htmlhelp==2.0.1 - # via sphinx -sphinxcontrib-jsmath==1.0.1 - # via sphinx -sphinxcontrib-qthelp==1.0.3 - # via sphinx -sphinxcontrib-serializinghtml==1.1.5 - # via sphinx -sqlalchemy==1.4.46 - # via jupyter-cache -stack-data==0.6.2 - # via ipython -tabulate==0.9.0 - # via jupyter-cache -tornado==6.2 - # via - # ipykernel - # jupyter-client -traitlets==5.9.0 - # via - # comm - # ipykernel - # ipython - # jupyter-client - # jupyter-core - # matplotlib-inline - # nbclient - # nbformat -typing-extensions==4.5.0 - # via - # myst-nb - # myst-parser -uc-micro-py==1.0.1 - # via linkify-it-py -urllib3==1.26.15 - # via requests -wcwidth==0.2.6 - # via prompt-toolkit -wrapt==1.15.0 - # via deprecated -zipp==3.15.0 - # via importlib-metadata - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/docs/API_Reference_Guide.rst b/docs/API_Reference_Guide.rst index b59c6e302692cd6b9d4035ed27d5901e790a0233..f21d43c5939ab07d3e9b6cc2c6c361a782b13ff1 100644 --- a/docs/API_Reference_Guide.rst +++ b/docs/API_Reference_Guide.rst @@ -7,8 +7,8 @@ API Reference Guide Introduction ================= -This document contains details of the APIs for the Composable Kernel (CK) library and introduces some of the key design -principles that are used to write new classes that extend CK functionality. +This document contains details of the APIs for the Composable Kernel (CK) library and introduces +some of the key design principles that are used to write new classes that extend CK functionality. ================= Using CK API @@ -30,8 +30,8 @@ DeviceMem Kernels For Flashattention --------------------------- -The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This sections lists the classes that are -used in the CK GPU implementation of Flashattention. +The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This sections lists +the classes that are used in the CK GPU implementation of Flashattention. **Gridwise classes** diff --git a/docs/Contributors_Guide.rst b/docs/Contributors_Guide.rst index b2ddff398ce8efcd0c5d8be275d9dd09315d4349..41cb8f19156b389930a886be82aa5474b5465a37 100644 --- a/docs/Contributors_Guide.rst +++ b/docs/Contributors_Guide.rst @@ -2,7 +2,101 @@ Contributor's Guide =================== -Pull-request guidelines -======================= +This chapter explains how to get started contributing to the Composable Kernel project and what are +the contributing rules. -[TODO] +Getting started +=============== + +#. **Documentation:** Before contributing to the library, familiarize yourself with the + `Composable Kernel User Guide `_. + It provides insight into the core concepts, environment configuration, and steps to obtain or + build the library. You can also find some of this information in the + `README file `_ + on the project's GitHub page. +#. **Additional reading:** We also recommend reading a `blog post + `_ + from the AMD Community portal. It offers a deeper understanding of the library's objectives and + showcases its performance capabilities. +#. **General information:** For broader information about AMD products, consider exploring the + `AMD Developer Central portal `_. + +How do I contribute +=================== + +We deeply value contributions from our users. You can make an impact by reporting issues or +proposing code enhancements through pull requests. + +Reporting issues +---------------- + +We use `Github issues `_ +to track public bugs and enhancement requests. + +If you encounter an issue with the library, please check if the problem has already been +reported by searching existing issues on GitHub. If your issue seems unique, please submit a new +issue. All reported issues must include: + +* A comprehensive description of the problem, including: + + * What did you observe? + * Why do you think it is a bug (if it seems like one)? + * What did you expect to happen? What would indicate the resolution of the problem? + * Are there any known workarounds? + +* Your configuration details, including: + + * Which GPU are you using? + * Which OS version are you on? + * Which ROCm version are you using? + * Are you using a Docker image? If so, which one? + +* Steps to reproduce the issue, including: + + * What actions trigger the issue? What are the reproduction steps? + + * If you build the library from scratch, what CMake command did you use? + + * How frequently does this issue happen? Does it reproduce every time? Or is it a sporadic issue? + +Before sumbitting any issue, ensure you have addressed all relevant questions from the checklist. + +Creating Pull Requests +---------------------- + +You can submit `Pull Requests (PR) on GitHub +`_. + +All contributors are required to develop their changes on a separate branch and then create a +pull requrest to merge their changes into the `develop` branch, which is the default +development branch in the Composable Kernel project. All external contributors must use their own +forks of the project to develop their changes. + +When submitting a Pull Request you should: + +* Describe the change providing information about the motivation for the change and a general + description of all code modifications. + +* Verify and test the change: + + * Run any relevant existing tests. + * Write new tests if added functionality is not covered by current tests. + +* Ensure your changes align with the coding style defined in the ``.clang-format`` file located in + the project's root directory. We leverage `pre-commit` to run `clang-format` automatically. We + highly recommend contributors utilize this method to maintain consistent code formatting. + Instructions on setting up `pre-commit` can be found in the project's + `README file `_ + +* Link your PR to any related issues: + + * If there is an issue that is resolved by your change, please provide a link to the issue in + the description of your pull request. + +* For larger contributions, structure your change into a sequence of smaller, focused commits, each + addressing a particular aspect or fix. + +Following the above guidelines ensures a seamless review process and faster assistance from our +end. + +Thank you for your commitment to enhancing the Composable Kernel project! We look forward to collaborating with you. diff --git a/docs/Supported_Primitives_Guide.rst b/docs/Supported_Primitives_Guide.rst index 4c3adf67d7119e22a622373ab8d1dccb4d024e18..3462283d90d62f1a181a1cc26d04a36f9d6f4fff 100644 --- a/docs/Supported_Primitives_Guide.rst +++ b/docs/Supported_Primitives_Guide.rst @@ -2,15 +2,16 @@ Supported Primitives Guide ========================== -This document contains details of supported primitives in Composable Kernel (CK). In contrast to the API Reference -Guide, the Supported Primitives Guide is an introduction to the math which underpins the algorithms implemented in CK. +This document contains details of supported primitives in Composable Kernel (CK). In contrast to the +API Reference Guide, the Supported Primitives Guide is an introduction to the math which underpins +the algorithms implemented in CK. ------------ Softmax ------------ -For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can decompose the softmax of concatenated -:math:`x = [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ]` as, +For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can decompose the +softmax of concatenated :math:`x = [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ]` as, .. math:: :nowrap: @@ -25,8 +26,8 @@ For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can d where :math:`f(x^{(j)}) = \exp( x^{(j)} - m(x^{(j)}) )` is of size :math:`B` and :math:`z(x^{(j)}) = f(x_1^{(j)})+ \ldots+ f(x_B^{(j)})` is a scalar. -For a matrix :math:`X` composed of :math:`T_r \times T_c` tiles, :math:`X_{ij}`, of size :math:`B_r \times B_c` we can -compute the row-wise softmax as follows. +For a matrix :math:`X` composed of :math:`T_r \times T_c` tiles, :math:`X_{ij}`, of size +:math:`B_r \times B_c` we can compute the row-wise softmax as follows. For :math:`j` from :math:`1` to :math:`T_c`, and :math:`i` from :math:`1` to :math:`T_r` calculate, diff --git a/docs/conf.py b/docs/conf.py index 3ec81ee9df9a0dbd224ca9dbf6fbe4115c155ec5..0de590da1a53daba8dc0e379e68abbae250e0667 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -4,10 +4,21 @@ # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html +import subprocess + from rocm_docs import ROCmDocs -docs_core = ROCmDocs("Composable Kernel Documentation") -docs_core.run_doxygen() + +name = "Composable Kernel" +get_version = r'sed -n -e "s/^rocm_setup_version(.* \([0-9\.]\{1,\}\).*/\1/p" ../CMakeLists.txt' +version = subprocess.getoutput(get_version) +if len(version) > 0: + name = f"{name} {version}" + +external_toc_path = "./sphinx/_toc.yml" + +docs_core = ROCmDocs(f"{name} Documentation") +docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/docBin/xml") docs_core.setup() mathjax3_config = { diff --git a/docs/dockerhub.rst b/docs/dockerhub.rst index b51226cfebe6da27988ec2685eb334f5bc900ff9..66ec91096e01b9ebfbe78559b9a168711a7b4206 100644 --- a/docs/dockerhub.rst +++ b/docs/dockerhub.rst @@ -1,27 +1,27 @@ =================== -CK docker hub +CK Docker Hub =================== -`Docker hub `_ - ------------------------------------- Why do I need this? ------------------------------------- -To make our lives easier and bring Composable Kernel dependencies together, we recommend using docker images. +To make our lives easier and bring Composable Kernel dependencies together, we recommend using +docker images that can be found on `Docker Hub `_. ------------------------------------- So what is Composable Kernel? ------------------------------------- -Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. +Composable Kernel (CK) library aims to provide a programming model for writing performance critical +kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, +through general purpose kernel languages, like HIP C++. To get the CK library:: git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git - run a docker container:: docker run \ @@ -30,7 +30,7 @@ run a docker container:: --group-add sudo \ -w /root/workspace \ -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ - rocm/composable_kernel:ck_ub20.04_rocm5.3_release \ + rocm/composable_kernel:ck_ub20.04_rocm5.6 \ /bin/bash and build the CK:: @@ -58,7 +58,9 @@ We can also run specific examples or tests like:: ./bin/example_gemm_xdl_fp16 ./bin/test_gemm_fp16 -For more details visit `CK github repo `_, `CK examples `_, `even more CK examples `_. +For more details visit `CK github repository `_, +`CK examples `_, +`even more CK examples `_. ------------------------------------- And what is inside? @@ -74,12 +76,11 @@ The docker images have everything you need for running CK including: Which image is right for me? ------------------------------------- -Let's take a look at the image naming, for example "ck_ub20.04_rocm5.4_release". The image specs are: +Let's take a look at the image naming, for example ``ck_ub20.04_rocm5.6``. The image specs are: -* "ck" - made for running Composable Kernel -* "ub20.04" - based on Ubuntu 20.04 -* "rocm5.4" - ROCm platform version 5.4 -* "release" - compiler version is release +* ``ck`` - made for running Composable Kernel; +* ``ub20.04`` - based on Ubuntu 20.04; +* ``rocm5.6`` - ROCm platform version 5.6. So just pick the right image for your project dependencies and you're all set. @@ -87,7 +88,9 @@ So just pick the right image for your project dependencies and you're all set. DIY starts here ------------------------------------- -If you need to customize a docker image or just can't stop tinkering, feel free to adjust the `Dockerfile `_ for your needs. +If you need to customize a docker image or just can't stop tinkering, feel free to adjust the +`Dockerfile `_ +for your needs. ------------------------------------- License diff --git a/docs/.doxygen/Doxyfile b/docs/doxygen/Doxyfile similarity index 100% rename from docs/.doxygen/Doxyfile rename to docs/doxygen/Doxyfile diff --git a/docs/index.rst b/docs/index.rst index f4e66c1b51f4528fa47057116a8f457680b3d972..51c0c862ae3e3e8b314aa37bd63b6d762b895607 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,12 +12,15 @@ This document contains instructions for installing, using, and contributing to C Methodology ----------- -Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. +Composable Kernel (CK) library aims to provide a programming model for writing performance critical +kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, +through general purpose kernel languages, like HIP C++. CK utilizes two concepts to achieve performance portability and code maintainability: * A tile-based programming model -* Algorithm complexity reduction for complex ML operators, using innovative technique we call "Tensor Coordinate Transformation". +* Algorithm complexity reduction for complex ML operators, using innovative technique we call + "Tensor Coordinate Transformation". .. image:: data/ck_component.png :alt: CK Components diff --git a/docs/license.rst b/docs/license.rst new file mode 100644 index 0000000000000000000000000000000000000000..ddb544496e41b92dcaadb6ff816946e46ca16da8 --- /dev/null +++ b/docs/license.rst @@ -0,0 +1,6 @@ +======= +License +======= + +.. include:: ../LICENSE + :literal: diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in new file mode 100644 index 0000000000000000000000000000000000000000..83dd1e7b1a5d759143f2db13a1e296f5e66da96d --- /dev/null +++ b/docs/sphinx/_toc.yml.in @@ -0,0 +1,10 @@ +# Anywhere {branch} is used, the branch name will be substituted. +# These comments will also be removed. +defaults: + numbered: False + maxdepth: 6 +root: index +subtrees: + - caption: About + entries: + - file: license diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in new file mode 100644 index 0000000000000000000000000000000000000000..c4ce8be79a204c9fb8d7129f82c02d3b4c9c0d53 --- /dev/null +++ b/docs/sphinx/requirements.in @@ -0,0 +1,2 @@ +rocm-docs-core>=0.20.0 +sphinxcontrib-bibtex==2.6.1 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a83a4e273862f707db71bc88e4bc66f75746415a --- /dev/null +++ b/docs/sphinx/requirements.txt @@ -0,0 +1,159 @@ +# +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: +# +# pip-compile requirements.in +# +accessible-pygments==0.0.3 + # via pydata-sphinx-theme +alabaster==0.7.13 + # via sphinx +babel==2.12.1 + # via + # pydata-sphinx-theme + # sphinx +beautifulsoup4==4.11.2 + # via pydata-sphinx-theme +breathe==4.34.0 + # via rocm-docs-core +certifi==2022.12.7 + # via requests +cffi==1.15.1 + # via + # cryptography + # pynacl +charset-normalizer==3.1.0 + # via requests +click==8.1.3 + # via sphinx-external-toc +cryptography==40.0.2 + # via pyjwt +deprecated==1.2.13 + # via pygithub +docutils==0.16 + # via + # breathe + # myst-parser + # pybtex-docutils + # pydata-sphinx-theme + # sphinx + # sphinxcontrib-bibtex +fastjsonschema==2.18.0 + # via rocm-docs-core +gitdb==4.0.10 + # via gitpython +gitpython==3.1.35 + # via rocm-docs-core +idna==3.4 + # via requests +imagesize==1.4.1 + # via sphinx +jinja2==3.1.2 + # via + # myst-parser + # sphinx +latexcodec==2.0.1 + # via pybtex +markdown-it-py==2.2.0 + # via + # mdit-py-plugins + # myst-parser +markupsafe==2.1.2 + # via jinja2 +mdit-py-plugins==0.3.5 + # via myst-parser +mdurl==0.1.2 + # via markdown-it-py +myst-parser==1.0.0 + # via rocm-docs-core +packaging==23.0 + # via + # pydata-sphinx-theme + # sphinx +pybtex==0.24.0 + # via + # pybtex-docutils + # sphinxcontrib-bibtex +pybtex-docutils==1.0.2 + # via sphinxcontrib-bibtex +pycparser==2.21 + # via cffi +pydata-sphinx-theme==0.13.3 + # via + # rocm-docs-core + # sphinx-book-theme +pygithub==1.58.2 + # via rocm-docs-core +pygments==2.14.0 + # via + # accessible-pygments + # pydata-sphinx-theme + # sphinx +pyjwt[crypto]==2.6.0 + # via pygithub +pynacl==1.5.0 + # via pygithub +pyyaml==6.0 + # via + # myst-parser + # pybtex + # rocm-docs-core + # sphinx-external-toc +requests==2.28.2 + # via + # pygithub + # sphinx +rocm-docs-core==0.24.0 + # via -r requirements.in +six==1.16.0 + # via + # latexcodec + # pybtex +smmap==5.0.0 + # via gitdb +snowballstemmer==2.2.0 + # via sphinx +soupsieve==2.4 + # via beautifulsoup4 +sphinx==5.3.0 + # via + # breathe + # myst-parser + # pydata-sphinx-theme + # rocm-docs-core + # sphinx-book-theme + # sphinx-copybutton + # sphinx-design + # sphinx-external-toc + # sphinx-notfound-page + # sphinxcontrib-bibtex +sphinx-book-theme==1.0.1 + # via rocm-docs-core +sphinx-copybutton==0.5.1 + # via rocm-docs-core +sphinx-design==0.3.0 + # via rocm-docs-core +sphinx-external-toc==0.3.1 + # via rocm-docs-core +sphinx-notfound-page==0.8.3 + # via rocm-docs-core +sphinxcontrib-applehelp==1.0.4 + # via sphinx +sphinxcontrib-bibtex==2.6.1 + # via -r requirements.in +sphinxcontrib-devhelp==1.0.2 + # via sphinx +sphinxcontrib-htmlhelp==2.0.1 + # via sphinx +sphinxcontrib-jsmath==1.0.1 + # via sphinx +sphinxcontrib-qthelp==1.0.3 + # via sphinx +sphinxcontrib-serializinghtml==1.1.5 + # via sphinx +typing-extensions==4.5.0 + # via pydata-sphinx-theme +urllib3==1.26.15 + # via requests +wrapt==1.15.0 + # via deprecated diff --git a/docs/tutorial_hello_world.rst b/docs/tutorial_hello_world.rst index b8fd094654cba671b5082d8052b7dd3d6671867a..bfb197e085a2a986ae50aea6b6eeebd8d203eab1 100644 --- a/docs/tutorial_hello_world.rst +++ b/docs/tutorial_hello_world.rst @@ -6,15 +6,26 @@ CK Hello world Motivation ------------------------------------- -This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who would like to optimize their pipelines and squeeze every performance drop by adding Composable Kernel (CK) library to their projects. We would like to make the CK library approachable so the tutorial is not based on the latest release and doesn't have all the bleeding edge features, but it will be reproducible now and forever. +This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who +would like to optimize their pipelines and squeeze every performance drop by adding Composable +Kernel (CK) library to their projects. We would like to make the CK library approachable so +the tutorial is not based on the latest release and doesn't have all the bleeding edge features, +but it will be reproducible now and forever. -During this tutorial we will have an introduction to the CK library, we will build it and run some examples and tests, so to say we will run a "Hello world" example. In future tutorials we will go in depth and breadth and get familiar with other tools and ways to integrate CK into your project. +During this tutorial we will have an introduction to the CK library, we will build it and run some +examples and tests, so to say we will run a "Hello world" example. In future tutorials we will go +in depth and breadth and get familiar with other tools and ways to integrate CK into your project. ------------------------------------- Description ------------------------------------- -Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and efficient workflows is still challenging. CK is one of the tools to make AI heavy lifting as fast and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create new ones. The library has components required for majority of modern neural networks architectures including matrix multiplication, convolution, contraction, reduction, attention modules, variety of activation functions, fused operators and many more. +Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and +efficient workflows is still challenging. CK is one of the tools to make AI heavy lifting as fast +and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create +new ones. The library has components required for majority of modern neural networks architectures +including matrix multiplication, convolution, contraction, reduction, attention modules, variety of +activation functions, fused operators and many more. So how do we (almost) reach the speed of light? CK acceleration abilities are based on: @@ -24,15 +35,18 @@ So how do we (almost) reach the speed of light? CK acceleration abilities are ba * Hardware acceleration use. * Support of low precision data types including fp16, bf16, int8 and int4. -If you are excited and need more technical details and benchmarking results - read this awesome `blog post `_. +If you are excited and need more technical details and benchmarking results - read this awesome +`blog post `_. -For more details visit our `github repo `_. +For more details visit our `github repository `_. ------------------------------------- Hardware targets ------------------------------------- -CK library fully supports "gfx908" and "gfx90a" GPU architectures and only some operators are supported for "gfx1030". Let's check the hardware you have at hand and decide on the target GPU architecture +CK library fully supports `gfx908` and `gfx90a` GPU architectures and only some operators are +supported for `gfx1030`. Let's check the hardware you have at hand and decide on the target +GPU architecture. ========== ========= GPU Target AMD GPU @@ -42,7 +56,8 @@ gfx90a Radeon Instinct MI210, MI250, MI250X gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6800 XT, RX 6900 XT, RX 6900 XTX, RX 6950 XT ========== ========= -There are also `cloud options `_ you can find if you don't have an AMD GPU at hand. +There are also `cloud options `_ you can find if +you don't have an AMD GPU at hand. ------------------------------------- Build the library @@ -54,9 +69,13 @@ First let's clone the library and rebase to the tested version:: cd composable_kernel/ git checkout tutorial_hello_world -To make our lives easier we prepared `docker images `_ with all the necessary dependencies. Pick the right image and create a container. In this tutorial we use "rocm/composable_kernel:ck_ub20.04_rocm5.3_release" image, it is based on Ubuntu 20.04, ROCm v5.3, compiler release version. +To make our lives easier we prepared +`docker images `_ with all the necessary +dependencies. Pick the right image and create a container. In this tutorial we use +``rocm/composable_kernel:ck_ub20.04_rocm5.6`` image, it is based on Ubuntu 20.04 and +ROCm v5.6. -If your current folder is ${HOME}, start the docker container with:: +If your current folder is ``${HOME}``, start the docker container with:: docker run \ -it \ @@ -64,20 +83,23 @@ If your current folder is ${HOME}, start the docker container with:: --group-add sudo \ -w /root/workspace \ -v ${HOME}:/root/workspace \ - rocm/composable_kernel:ck_ub20.04_rocm5.3_release \ + rocm/composable_kernel:ck_ub20.04_rocm5.6 \ /bin/bash -If your current folder is different from ${HOME}, adjust the line `-v ${HOME}:/root/workspace` to fit your folder structure. +If your current folder is different from ``${HOME}``, adjust the line ``-v ${HOME}:/root/workspace`` +to fit your folder structure. -Inside the docker container current folder is "~/workspace", library path is "~/workspace/composable_kernel", navigate to the library:: +Inside the docker container current folder is ``~/workspace``, library path is +``~/workspace/composable_kernel``, navigate to the library:: cd composable_kernel/ -Create and go to the "build" directory:: +Create and go to the ``build`` directory:: mkdir build && cd build -In the previous section we talked about target GPU architecture. Once you decide which one is right for you, run cmake using the right GPU_TARGETS flag:: +In the previous section we talked about target GPU architecture. Once you decide which one is right +for you, run CMake using the right ``GPU_TARGETS`` flag:: cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ @@ -87,7 +109,7 @@ In the previous section we talked about target GPU architecture. Once you decide -D BUILD_DEV=OFF \ -D GPU_TARGETS="gfx908;gfx90a;gfx1030" .. -If everything went well the cmake run will end up with:: +If everything went well the CMake run will end up with:: -- Configuring done -- Generating done @@ -118,9 +140,12 @@ We can also run them separately, here is a separate example execution:: ./bin/example_gemm_xdl_fp16 1 1 1 -The arguments "1 1 1" mean that we want to run this example in the mode: verify results with CPU, initialize matrices with integers and benchmark the kernel execution. You can play around with these parameters and see how output and execution results change. +The arguments ``1 1 1`` mean that we want to run this example in the mode: verify results with CPU, +initialize matrices with integers and benchmark the kernel execution. You can play around with +these parameters and see how output and execution results change. -If everything goes well and you have a device based on gfx908 or gfx90a architecture you should see something like:: +If everything goes well and you have a device based on `gfx908` or `gfx90a` architecture you should see +something like:: a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} @@ -130,14 +155,15 @@ If everything goes well and you have a device based on gfx908 or gfx90a architec Start running 10 times... Perf: 1.10017 ms, 117.117 TFlops, 87.6854 GB/s, DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 -Meanwhile, running it on a gfx1030 device should result in:: +Meanwhile, running it on a `gfx1030` device should result in:: a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem -But don't panic, some of the operators are supported on gfx1030 architecture, so you can run a separate example like:: +But don't panic, some of the operators are supported on `gfx1030` architecture, so you can run a +separate example like:: ./bin/example_gemm_dl_fp16 1 1 1 @@ -154,7 +180,14 @@ and it should result in something nice similar to:: Start running 10 times... Perf: 3.65695 ms, 35.234 TFlops, 26.3797 GB/s, DeviceGemmDl<256, 128, 128, 16, 2, 4, 4, 1> -Or we can run a separate test:: +.. note:: + + There was a new CMake flag ``DL_KERNELS`` added in the latest versions of CK. If you use one of + the newest versions of the library and do not see the above results when running + ``example_gemm_dl_fp16``, it might be necessary to add ``-D DL_KERNELS=ON`` to your CMake command + in order to build the operators supported on the `gfx1030` architecture. + +We can also run a separate test:: ctest -R test_gemm_fp16 @@ -169,6 +202,9 @@ If everything goes well you should see something like:: Summary ----------- -In this tutorial we took the first look at the Composable Kernel library, built it on your system and ran some examples and tests. Stay tuned, in the next tutorial we will run kernels with different configs to find out the best one for your hardware and task. +In this tutorial we took the first look at the Composable Kernel library, built it on your system +and ran some examples and tests. Stay tuned, in the next tutorial we will run kernels with different +configs to find out the best one for your hardware and task. -P.S.: Don't forget to switch out the cloud instance if you have launched one, you can find better ways to spend your money for sure! +P.S.: Don't forget to switch off the cloud instance if you have launched one, you can find better +ways to spend your money for sure! diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index c5a8295188ac2e088a320fbe25a264ef8aeec2cb..42d7311d3de4bae3967765ac8b4606d6605aa3b7 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -1,46 +1,60 @@ add_custom_target(example_gemm_dl) add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) +add_example_dependencies(example_gemm_dl example_gemm_dl_fp32) + add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) -add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) +add_example_dependencies(example_gemm_dl example_gemm_dl_fp16) -add_dependencies(example_gemm_dl example_gemm_dl_fp32) -add_dependencies(example_gemm_dl example_gemm_dl_fp16) -add_dependencies(example_gemm_dl example_gemm_dl_int8) +add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp) +add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) +add_example_dependencies(example_gemm_dl example_gemm_dl_int8) if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_int4) + add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) + add_example_dependencies(example_gemm_dl example_gemm_dl_int4) endif(USE_BITINT_EXTENSION_INT4) - add_custom_target(example_gemm_xdl) - add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16) + add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) + +add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) +if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") + add_custom_target(example_gemm_wmma) + add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) + add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) +endif() + add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) -add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16) -add_dependencies(example_gemm_xdl example_gemm_xdl_fp16) -add_dependencies(example_gemm_xdl example_gemm_xdl_bf16) -add_dependencies(example_gemm_xdl example_gemm_xdl_int8) -add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) +add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn) + +add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_int8) if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_int4) + add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) -add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) -add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) +add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") - add_custom_target(example_gemm_wmma) - add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) - add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) -endif() +add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) + +add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) +add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 495a8159623beb22d6002bb2e1667ef459b74139..7fd15b2833d2d7522fa3d87029493415fd596fad 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -33,6 +33,19 @@ struct ProblemSize final ck::index_t StrideC = 4096; }; +struct ProblemSizeStreamK final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + ck::index_t NumSKBlocks = -1; +}; + struct ExecutionConfig final { bool do_verification = true; @@ -48,8 +61,17 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -inline bool -parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +template +bool parse_cmd_args(int, char*[], ProblemType&, ExecutionConfig&) +{ + return false; +} + +template <> +bool parse_cmd_args(int argc, + char* argv[], + ProblemSize& problem_size, + ExecutionConfig& config) { if(argc == 1) { @@ -87,3 +109,52 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi return true; } + +template <> +bool parse_cmd_args(int argc, + char* argv[], + ProblemSizeStreamK& problem_size, + ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc >= 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + + if(argc >= 11) + { + problem_size.NumSKBlocks = std::stoi(argv[10]); + } + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl + << "arg10: NumSKBlocks(optional)" << std::endl; + return false; + } + + return true; +} diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp index cf585a8c51cbb9b5b0218228e4aa706189598804..b5fecb97521bd2d4213a9435b01dafe45a43be6e 100644 --- a/example/01_gemm/gemm_dl_fp16.cpp +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp index 93f085cdee53667a5b906cfed3b037d57d00bf5f..212b72f2a6a060a602c5c640709da77684c7c55e 100644 --- a/example/01_gemm/gemm_dl_fp32.cpp +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_dl_int4.cpp b/example/01_gemm/gemm_dl_int4.cpp index e392c490f29a48da3a6424f46876e220306f2907..e55ae140130c0779c1e81e0bdc84b6724c6bac86 100644 --- a/example/01_gemm/gemm_dl_int4.cpp +++ b/example/01_gemm/gemm_dl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp index be9e387718f120fb1ba708d374f0ff9a09fc806d..1840390aa9e02e85c88328baf358546db362ab24 100644 --- a/example/01_gemm/gemm_dl_int8.cpp +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_dpp_fp16.cpp b/example/01_gemm/gemm_dpp_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7a9e3f6186cc1525c8d882e9a961ae6ec06f95b2 --- /dev/null +++ b/example/01_gemm/gemm_dpp_fp16.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CDataType = ck::half_t; + +using F16 = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDpp +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 64, 64, 64, 8, 2, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 5, 1>; +// // clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index 58f965be8817bb5761dd677a52966d21c1bc1024..b11fe76ab2ce031708be7addd7f84be0f15adc6b 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 9aaae6ade9564aa54bc4ea0c7f2c96aac89ff505..3cac55ef4702856dd08dc223c02ea25113dbbf32 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_bf16_rtn.cpp b/example/01_gemm/gemm_xdl_bf16_rtn.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cc14dcb8eb48f37b7b3597a02407ef440e17ab37 --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16_rtn.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/utility/type_convert.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = ck::tensor_operation::element_wise::ConvertBF16RTN; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 50d35fd9ac98de821fe4861b2dbf34907d3592eb..54fbd9cdd4983c1a810f3c8a1ac1521b363af690 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_fp16_fp8.cpp b/example/01_gemm/gemm_xdl_fp16_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d3cf3d397a5fd329a58dc3a967aa4c8f0bb2b696 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_fp8.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto LoopSched = ck::make_default_loop_scheduler(); +static constexpr auto PipelineVer = ck::PipelineVersion::v1; +using ComputeType = ck::half_t; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| ComputeType| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| | +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, PipelineVer, ComputeType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp index 99253b743d58707707c9765c18e5933cb09e4220..8361576299c3da3727cc467387e9717fceb93ce3 100644 --- a/example/01_gemm/gemm_xdl_fp64.cpp +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_fp8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d4df3fc130369dcbb71ca5eb004da3f57df7662 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp8_bf8.cpp b/example/01_gemm/gemm_xdl_fp8_bf8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b54df8ff3d350c1ef33102052589f589edbe59d5 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp8_bf8.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::bf8_t; +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto LoopSched = ck::make_default_loop_scheduler(); +static constexpr auto PipelineVer = ck::PipelineVersion::v1; +using ComputeTypeA = ck::f8_t; +using ComputeTypeB = ck::bf8_t; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_int4.cpp b/example/01_gemm/gemm_xdl_int4.cpp index 7f1283a47b36c94be9645abd9fbd63094293f46e..f6238c7aa5040d8e409a139fe8c144835282cdc9 100644 --- a/example/01_gemm/gemm_xdl_int4.cpp +++ b/example/01_gemm/gemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index e67594c5bcbd601ae4747ac7720dab66b605ba2c..cc03200b9d153e8e22e13e6c77a52ad76ed79174 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp index 12a69925977b7eef2e5aa178016a7d433e9c0781..4a0c23cf44c6de8fe40f1f581e6016c597ae8041 100644 --- a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp +++ b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -204,9 +204,9 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); diff --git a/example/01_gemm/gemm_xdl_streamk.cpp b/example/01_gemm/gemm_xdl_streamk.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7d433b6145017c3aa19f4ae5a9e4bf31101b1025 --- /dev/null +++ b/example/01_gemm/gemm_xdl_streamk.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = ck::half_t; + +using F16 = ck::half_t; + +using ALayout = Row; +using BLayout = Row; +// using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +// clang-format off +using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + + // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8>; + // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>; + // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 128, 4, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8>; + + + +// // clang-format on +// clang-format on + +using DeviceGemmInstance = DeviceGemmStreamK; + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp index 3a0ddd90b700f7086b44a8b0f4e5b9d7385d0ad2..b0f963fee51c20b089811ba2b6396b9d7125d8f6 100644 --- a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp +++ b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp" using ADataType = ck::half_t; using BDataType = ck::half_t; diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 4e2cedb52adc8c7030167ec7f171d91f0796f234..7be2539d903974cd676e92ddb3d692a997abf1cf 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -1,9 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { #if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); @@ -11,7 +14,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) using namespace ck::literals; - auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size; + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { @@ -25,12 +33,37 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) } }; + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); switch(config.init_method) { - case 0: break; + case 0: + ck::utils::FillConstant{static_cast(1.f)}(a_m_k); + ck::utils::FillConstant{static_cast(1.f)}(b_k_n); + break; case 1: ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); @@ -66,43 +99,115 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); #endif + DeviceMem workspace; auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; + using BaseStreamK = ck::tensor_operation::device::DeviceGemmStreamK; + // do GEMM - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument( + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + if constexpr(std::is_same::value && + !std::is_base_of::value) + { + auto argument = gemm.MakeArgument( #ifdef BUILD_INT4_EXAMPLE - static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), #else - static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), #endif - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op); - - if(!gemm.IsSupportedArgument(argument)) + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + } + else if constexpr(std::is_same::value && + std::is_base_of::value) { - std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; - - return true; + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + problem_size.NumSKBlocks); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument); + if(workspace_size != 0) + { + workspace.Realloc(workspace_size); + gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer()); + } + + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + +#if 0 + // TODO!!!!! + if(workspace_size != 0){ + float * ws_ptr = reinterpret_cast(malloc(workspace_size)); + size_t ws_dwords = workspace_size / sizeof(float); + workspace.FromDevice(ws_ptr); + + for(size_t i = 0; i < ws_dwords; i++) { + uint32_t rere = reinterpret_cast(ws_ptr)[i]; + printf("%4lu : %f(0x%08x)\n", i, ws_ptr[i], rere); + } + free(ws_ptr); + } +#endif } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - std::size_t flop = 2_uz * M * N * K; std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; @@ -149,3 +254,11 @@ bool run_gemm_example(int argc, char* argv[]) return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); } + +bool run_gemm_streamk_example(int argc, char* argv[]) +{ + ProblemSizeStreamK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index 16a82110276b7901df9e80fe3a8e3ccb64df22aa..d82c42d5a9d83f5c95e19f7b2f23dca7c3d2c5e9 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1,4 +1,20 @@ -add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") +list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102) +list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list1 AND target EQUAL 0) add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) + add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp) endif() +if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") + set(target 1) + endif() +endforeach() + +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list2 AND target EQUAL 0) + add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index ff99bf4641115e466775995de2761ba52897783d..877792d7409f48ee562cded6629f7e1bbde650d9 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f23ad26522f4a21c18db74fe5dd4468d3ff50e7 --- /dev/null +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +struct AlphaBetaAdd +{ + AlphaBetaAdd(int alpha, int beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + std::int8_t& e, const std::int32_t& c, const std::int8_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + int alpha_; + int beta_; +}; + +template +using S = ck::Sequence; + +using I8 = std::int8_t; +using I32 = std::int32_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using DDataType = I8; +using EDataType = I8; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 32, + 16, + 16, + 4, + 16, + 16, + 16, + 1, + 1, + S<2, 16, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + 1, + S<4, 1, 8>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 16, + 2, + 1, + 1, + 1, + S<1, 16, 1, 2>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + int alpha = 1; + int beta = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp index 917b6b1c3142ec8cf7d1c852ea59494fea1842a9..c3e6ef7d5df2e12de69e2f9465f8f3e3709c0859 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt index 35c54abac03094f24187df2503aa02b6812c20f3..2f5cba924dfd044a1ff0717d75a8b6e7efc3d16b 100644 --- a/example/03_gemm_bias_relu/CMakeLists.txt +++ b/example/03_gemm_bias_relu/CMakeLists.txt @@ -1 +1,8 @@ -add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp index aee51d05de58b6386a5dd267d45a7e5c12d98276..dffeff23374b6f868fcce952ef8991ca289f59b7 100644 --- a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp +++ b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index c75c5ba51e899e4d0cba6564f52a2abb63854dd8..3486b15571d3b719aab7a172d6e0a86cf371567b 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,17 +1,24 @@ -add_custom_target(example_gemm_add_add_fastgelu_xdl) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_gemm_add_add_fastgelu_xdl) + add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) -add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) -add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) -add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) -if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) -endif(USE_BITINT_EXTENSION_INT4) -add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) + add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) -add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) -add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) -add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) -if(USE_BITINT_EXTENSION_INT4) - add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) -endif(USE_BITINT_EXTENSION_INT4) -add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) + add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) + + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) + endif(USE_BITINT_EXTENSION_INT4) + + add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) + set(target 1) + endif() +endforeach() diff --git a/example/04_gemm_add_add_fastgelu/common.hpp b/example/04_gemm_add_add_fastgelu/common.hpp index 839587c148956cf186e1ba5ac286a1b73797fade..91d17df95fc470ea23011999d0a929b9331d0cc2 100644 --- a/example/04_gemm_add_add_fastgelu/common.hpp +++ b/example/04_gemm_add_add_fastgelu/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp index ba0476b9b9e3bc031cb0ea61eda66970c2544f39..e630f6783713517977f63a9a2a1ef20836af2c61 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp index b940bfd89737ff4f460ef2b672a99a6696d57ce3..71f6677bae95d1779c76219f6c8b5bc05c3c1d31 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp index fa651a34ea86298a60dbea6158c356b9e50fdfce..4665c3932f419c837125ea3153437302cb9dc022 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp32.cpp @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp index 9f9c423de278b8687816dac1d518f843eb2eb34a..f206bbeb411bb0de0144015d70a892b5587d6f0e 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp index fadc4ef5ee47737bb33e19f1bab3264f99de5ea2..e46483ab38ac34cb5d4f97602474887d78779327 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index e0a53005b8c2846b6c3baaee40272c9ad0858188..f9903bfe0390725997e166be78f8b0650eaf80fb 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -1,11 +1,17 @@ -add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) -add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) -add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) -add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) -# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed -add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) + add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) + add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) + add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) + # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed + add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) + set(target 1) + endif() +endforeach() add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) - diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index 4c594ccdf817a069918b348d7b23c6cc533bde28..109b8f9ee34096601f0dd485db6840e0674c1937 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/09_convnd_fwd/convnd_fwd_dl_common.hpp b/example/09_convnd_fwd/convnd_fwd_dl_common.hpp index 855710b9d9a30ec47c74a6f31d1b03ade688e28e..aeddd4fc59fc521a196487bf7c0ea4e88bf6e1c8 100644 --- a/example/09_convnd_fwd/convnd_fwd_dl_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_dl_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp index db5a7f0bc3351c9f1670993942a54668c191e56f..564fadcbf8b10acc98e47a67cddd2cb0abe295bc 100644 --- a/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_dl_fp16.cpp @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_dl_common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp index 964d784c8592b2e94bc38cc8eab6b7a645d88fc1..0bd90c999233b5392f5914081196ff3b06cc8285 100644 --- a/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_dl_fp32.cpp @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_dl_common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp index b0cd88f214c8550c018ad9f4d268370dbd816c7f..9b0c8b31f172dcaf984c7c0c728a65fcb1a0554a 100644 --- a/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_dl_int8.cpp @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_dl_common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp index d55d3154916b0f4ba2fee195b7957904ae3d3139..74cf91d16017bd24096dac6348e5a8815856a51a 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index d84afba6426b1abe4b10a3b4da1fefe7c2e9272a..f6d69bafd48a4b8bff968c2fb8600793b1310294 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp index f5acc540cf98b8876caa0d5c0812436dd200deb8..6c3171f6157ac4a83a488af12089a97accfd2c7c 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp index 8d697976abd43e5d07c49ec4e0819cbdf15ac02b..9977a496d229876c083e0afe99e17ff0fa5fca05 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp index 99f7f2565c748a950841dffa12a1d343df0051ab..bf084b3cc0b6ceab326b2b21e963e0ce2d4bdcad 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" diff --git a/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc b/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc index 697ada14ba960f7231db3aa0d5b2482ca38578b8..6474df1c355bc059af5429f96616b2192f46f292 100644 --- a/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc +++ b/example/09_convnd_fwd/run_convnd_fwd_dl_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/09_convnd_fwd/run_convnd_fwd_example.inc b/example/09_convnd_fwd/run_convnd_fwd_example.inc index 36a68056f1d7a34b4309a06941f0ff1f477eaba8..49852ff6678f7c3fbbfce4ac8b1303876092cf40 100644 --- a/example/09_convnd_fwd/run_convnd_fwd_example.inc +++ b/example/09_convnd_fwd/run_convnd_fwd_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt index 98941b4db53b022388023e5a0eba10c99b5881be..222a3b7c0be18e426860d51ec44fc758c1f22d6b 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt @@ -1,16 +1,25 @@ -add_custom_target(example_convnd_fwd_reduce_xdl) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_fwd_reduce_xdl) -add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) -add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) -add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) -add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) + add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) -add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) -add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) -add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) -add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) + add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) -if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) - add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) -endif(USE_BITINT_EXTENSION_INT4) + add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) + + add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) + + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) + endif(USE_BITINT_EXTENSION_INT4) + set(target 1) + endif() +endforeach() diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp index 00e370f2968df0409155d5407066e863574613ce..137b0d1ff0fb0a7cb7159d52f54b4cd3496629ae 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp index 6ff29b4b0ff0f11d6c9d3becbd3aaafdccf3020e..4ccacb0bcee93e72eb1fd23203e9404f1c23c478 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp index 02c19c2b63bf4f9ee663140b414c89f6221cff51..bf495725e8e77a37019d1d858bf3fe3b69e828d1 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp index 679bb5c0c45a1e2e74fc56d21ee6f7f73cae76ad..5848785673c340149d622fd782671d5fb174edab 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp index abdbdaf74d5c0bd2b87da17d315e8712cd896d52..bf7127502faac94e6858d37c389f316b7ead35a6 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp index cf86afa8e94957c01fc92acb5ed2286fcb52466c..3e1694cbe8cae679ef249b6654efc40f8b1fc1ab 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc index b3a3891781766acd824bc1974b88f57fdd85b711..cebfeb51d63eac30243d8e3b0468b821f6fc1eb3 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/12_reduce/README.md b/example/12_reduce/README.md index 76d28527bb8aa521611e54d172b427aabf9fb998..bcffa684c8f2c5c34efbf4710656b0124dfa8a42 100644 --- a/example/12_reduce/README.md +++ b/example/12_reduce/README.md @@ -2,7 +2,7 @@ ## Run ```example_reduce_blockwise``` ```bash -# -D : input 3d/4d/5d tensor lengths +# -D : input 3D/4D/5D tensor lengths # -R : reduce dimension ids # -v : verification (0=no, 1=yes) #arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64, 7: int4) @@ -22,7 +22,7 @@ Perf: 0.238063 ms, 264.285 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSr ## Run ```example_reduce_multiblock_atomic_add``` ```bash -# -D : input 3d/4d/5d tensor lengths +# -D : input 3D/4D/5D tensor lengths # -R : reduce dimension ids # -v : verification (0=no, 1=yes) #arg1: data type (0: fp32, 1: fp64) diff --git a/example/12_reduce/reduce_blockwise.cpp b/example/12_reduce/reduce_blockwise.cpp index a7ee9990c1941a73b632f3cc1d32b14a00897cc2..9a736d4cfac9d1062c568e795592711e30d18586 100644 --- a/example/12_reduce/reduce_blockwise.cpp +++ b/example/12_reduce/reduce_blockwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/12_reduce/reduce_blockwise_impl.hpp b/example/12_reduce/reduce_blockwise_impl.hpp index e6e3cc8d52bc9fe6568d98194e644cf75ad1950a..7f8394a7301cc3da8271e1e8c41482d0b06a2859 100644 --- a/example/12_reduce/reduce_blockwise_impl.hpp +++ b/example/12_reduce/reduce_blockwise_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/12_reduce/reduce_blockwise_two_call.cpp b/example/12_reduce/reduce_blockwise_two_call.cpp index dbb18a0d83f62d27750c7797842af9bea0eb1383..eb8b5c76d31a1a0fc7a1a7e085416f80be9abf7f 100644 --- a/example/12_reduce/reduce_blockwise_two_call.cpp +++ b/example/12_reduce/reduce_blockwise_two_call.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/12_reduce/reduce_example_common.hpp b/example/12_reduce/reduce_example_common.hpp index 05f0a0edb25ba520d475522bd2963bc050ccc010..5f9a48804a7268cef5a4a065d0c875565bdb3e9e 100644 --- a/example/12_reduce/reduce_example_common.hpp +++ b/example/12_reduce/reduce_example_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/12_reduce/reduce_multiblock_atomic_add.cpp b/example/12_reduce/reduce_multiblock_atomic_add.cpp index c4d63a3add8b0d8016b11b8834a1385dcca5beb5..120e3f05957fdb9f937b290dfe09458db248f9c0 100644 --- a/example/12_reduce/reduce_multiblock_atomic_add.cpp +++ b/example/12_reduce/reduce_multiblock_atomic_add.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp b/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp index 905242fb6b5bca3c0dd7209a5d2827546036695e..fed62186448d58ce58c2a13c6dd184e205801185 100644 --- a/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp +++ b/example/12_reduce/reduce_multiblock_atomic_add_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/13_pool2d_fwd/CMakeLists.txt b/example/13_pool2d_fwd/CMakeLists.txt index db09c03321e5ac0f3408dcffbc940b6c9c0879c5..e2a923cded8fe2ceecabfaff5a3201817ae3700a 100644 --- a/example/13_pool2d_fwd/CMakeLists.txt +++ b/example/13_pool2d_fwd/CMakeLists.txt @@ -1,3 +1,2 @@ add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) - diff --git a/example/13_pool2d_fwd/pool2d_fwd_common.hpp b/example/13_pool2d_fwd/pool2d_fwd_common.hpp index b83cb6a96f0b10f0aeff5e57d386f2ac1c3fa4e4..3ce08fd2afceb083f16ff26e88caa7a0855779b0 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_common.hpp +++ b/example/13_pool2d_fwd/pool2d_fwd_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,115 +17,11 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp" template -static void pool_host_verify(const Tensor& in, - Tensor& out, - Tensor& out_indices, - const std::array& window_spatial_lengths, - const std::array& window_strides, - const std::array& in_left_pads, - const std::array& /*in_right_pads*/) -{ - const int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1]; - - using ReduceOperation = typename ck::reduce_binary_operator::opType; - - auto elementwise_ops = - ck::reduce_unary_operator::GetElementwiseOperator(reduceLength); - - auto in_elementwise_op = std::get<0>(elementwise_ops); - auto acc_elementwise_op = std::get<1>(elementwise_ops); - - if constexpr(!OutputIndex) - { - using Accumulation = - ck::detail::AccumulateWithNanCheck; - - auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { - auto accuVal = ReduceOperation::template GetIdentityValue(); - - for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) - { - ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0]; - for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x) - { - ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1]; - if(hi >= 0 && hi < static_cast(in.mDesc.GetLengths()[2]) && - wi >= 0 && wi < static_cast(in.mDesc.GetLengths()[3])) - { - AccDataType currVal = static_cast(in(n, c, hi, wi)); - - in_elementwise_op(currVal, currVal); - - Accumulation::Calculate(accuVal, currVal); - } - } - } - - acc_elementwise_op(accuVal, accuVal); - - out(n, c, ho, wo) = accuVal; - }; - - make_ParallelTensorFunctor(f_nchw, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck; - auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { - auto accuVal = ReduceOperation::template GetIdentityValue(); - IndexDataType accuIndex = 0; - - for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) - { - ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0]; - for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x) - { - ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - AccDataType currVal = static_cast(in(n, c, hi, wi)); - IndexDataType currIndex = y * window_spatial_lengths[1] + x; - - in_elementwise_op(currVal, currVal); - - Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex); - } - } - } - - acc_elementwise_op(accuVal, accuVal); - - out(n, c, ho, wo) = accuVal; - out_indices(n, c, ho, wo) = accuIndex; - }; - - make_ParallelTensorFunctor(f_nchw, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - }; -} - -template ; // InSrcOutDstVectorSize - - const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; - const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; - - const std::array window_spatial_lengths{{Y, X}}; - const std::array window_strides{{window_stride_h, window_stride_w}}; - const std::array input_left_pads{{in_left_pad_h, in_left_pad_w}}; - const std::array input_right_pads{{in_right_pad_h, in_right_pad_w}}; + ck::tensor_operation::device::DevicePool2dFwd_NHWC_NHWC; // InSrcOutDstVectorSize + + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; + + const std::vector window_spatial_lengths{Y, X}; + const std::vector window_strides{window_stride_h, window_stride_w}; + const std::vector window_dilations{window_dilation_h, window_dilation_w}; + const std::vector input_left_pads{in_left_pad_h, in_left_pad_w}; + const std::vector input_right_pads{in_right_pad_h, in_right_pad_w}; // tensor layout auto f_host_tensor_descriptor = @@ -219,14 +120,17 @@ bool pool_test(bool do_verification, static_cast(in_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), static_cast(out_indices_device_buf.GetDeviceBuffer()), - N, - C, - std::array{{Hi, Wi}}, - std::array{{Y, X}}, - std::array{{Ho, Wo}}, + {N, C, Hi, Wi}, + {Y, X}, + {N, C, Ho, Wo}, + {C * Hi * Wi, 1, Wi * C, C}, + {C * Ho * Wo, 1, Wo * C, C}, + {C * Ho * Wo, 1, Wo * C, C}, window_strides, + window_dilations, input_left_pads, - input_right_pads); + input_right_pads, + {2, 3}); if(!pool.IsSupportedArgument(argument_ptr.get())) { @@ -245,26 +149,36 @@ bool pool_test(bool do_verification, float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB / s " << std::endl; bool pass = true; if(do_verification) { - pool_host_verify(in_n_c_hi_wi, - out_n_c_ho_wo_host, - out_indices_n_c_ho_wo_host, - window_spatial_lengths, - window_strides, - input_left_pads, - input_right_pads); + using ReferencePoolingFwdInstance = + ck::tensor_operation::host::ReferencePoolingFwd<4, + 2, + InDataType, + OutDataType, + ComputeDataType, + IndexDataType, + ReduceOpId, + PropagateNan, + OutputIndex>; + + auto ref_pooling = ReferencePoolingFwdInstance{}; + auto ref_pooling_invoker = ref_pooling.MakeInvoker(); + auto ref_pooling_argument = ref_pooling.MakeArgument(in_n_c_hi_wi, + out_n_c_ho_wo_host, + out_indices_n_c_ho_wo_host, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + + ref_pooling_invoker.Run(ref_pooling_argument); out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); diff --git a/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp index 659f3251dcf9f8d7661cf44169458937f0d15088..d767e9224893067785f1ad5301b513061d0490b0 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp +++ b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp @@ -1,8 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include -#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -10,9 +9,9 @@ #include "pool2d_fwd_common.hpp" -using InDataType = ck::half_t; -using OutDataType = ck::half_t; -using AccDataType = float; +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using ComputeDataType = float; using IndexDataType = int32_t; @@ -35,18 +34,20 @@ int main(int argc, char* argv[]) bool time_kernel; // Pool shape - ck::index_t N = 128; - ck::index_t C = 192; - ck::index_t Y = 3; - ck::index_t X = 3; - ck::index_t Hi = 71; - ck::index_t Wi = 71; - ck::index_t window_stride_h = 2; - ck::index_t window_stride_w = 2; - ck::index_t in_left_pad_h = 1; - ck::index_t in_left_pad_w = 1; - ck::index_t in_right_pad_h = 1; - ck::index_t in_right_pad_w = 1; + ck::index_t N = 128; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; if(argc == 1) { @@ -60,38 +61,40 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = static_cast(std::stoi(argv[3])); } - else if(argc == 16) + else if(argc == 18) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); time_kernel = static_cast(std::stoi(argv[3])); - N = std::stoi(argv[4]); - C = std::stoi(argv[5]); - Y = std::stoi(argv[6]); - X = std::stoi(argv[7]); - Hi = std::stoi(argv[8]); - Wi = std::stoi(argv[9]); - window_stride_h = std::stoi(argv[10]); - window_stride_w = std::stoi(argv[11]); - in_left_pad_h = std::stoi(argv[12]); - in_left_pad_w = std::stoi(argv[13]); - in_right_pad_h = std::stoi(argv[14]); - in_right_pad_w = std::stoi(argv[15]); + N = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + window_stride_h = std::stoi(argv[10]); + window_stride_w = std::stoi(argv[11]); + window_dilation_h = std::stoi(argv[12]); + window_dilation_w = std::stoi(argv[13]); + in_left_pad_h = std::stoi(argv[14]); + in_left_pad_w = std::stoi(argv[15]); + in_right_pad_h = std::stoi(argv[16]); + in_right_pad_w = std::stoi(argv[17]); } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " + printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " "RightPx\n"); exit(0); } bool pass = pool_test -#include #include "ck/ck.hpp" #include "ck/utility/reduction_enums.hpp" @@ -10,9 +9,9 @@ #include "pool2d_fwd_common.hpp" -using InDataType = float; -using OutDataType = float; -using AccDataType = float; +using InDataType = float; +using OutDataType = float; +using ComputeDataType = float; using IndexDataType = int32_t; @@ -35,18 +34,20 @@ int main(int argc, char* argv[]) bool time_kernel; // Pool shape - ck::index_t N = 128; - ck::index_t C = 192; - ck::index_t Y = 3; - ck::index_t X = 3; - ck::index_t Hi = 71; - ck::index_t Wi = 71; - ck::index_t window_stride_h = 2; - ck::index_t window_stride_w = 2; - ck::index_t in_left_pad_h = 1; - ck::index_t in_left_pad_w = 1; - ck::index_t in_right_pad_h = 1; - ck::index_t in_right_pad_w = 1; + ck::index_t N = 128; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; if(argc == 1) { @@ -60,38 +61,40 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = static_cast(std::stoi(argv[3])); } - else if(argc == 16) + else if(argc == 18) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); time_kernel = static_cast(std::stoi(argv[3])); - N = std::stoi(argv[4]); - C = std::stoi(argv[5]); - Y = std::stoi(argv[6]); - X = std::stoi(argv[7]); - Hi = std::stoi(argv[8]); - Wi = std::stoi(argv[9]); - window_stride_h = std::stoi(argv[10]); - window_stride_w = std::stoi(argv[11]); - in_left_pad_h = std::stoi(argv[12]); - in_left_pad_w = std::stoi(argv[13]); - in_right_pad_h = std::stoi(argv[14]); - in_right_pad_w = std::stoi(argv[15]); + N = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + window_stride_h = std::stoi(argv[10]); + window_stride_w = std::stoi(argv[11]); + window_dilation_h = std::stoi(argv[12]); + window_dilation_w = std::stoi(argv[13]); + in_left_pad_h = std::stoi(argv[14]); + in_left_pad_w = std::stoi(argv[15]); + in_right_pad_h = std::stoi(argv[16]); + in_right_pad_w = std::stoi(argv[17]); } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " + printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " "RightPx\n"); exit(0); } bool pass = pool_test #include diff --git a/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp b/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp index d5f4e6f62c3eb8e58285e81446a5e1332c37c533..aa3e0116954e603ed5cd759820ea49184bb4cb8a 100644 --- a/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp b/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp index 2371737382447d666d55633fbf3ac976c3017cf7..4b207df5c628976c54d8b9c0f53ea140d51c98e5 100644 --- a/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_xdl_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index e3080c4a9ae0aafd9cd93072503a6eaeeee68495..84040fcf5cb939688f7a7b03095321b15548c400 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -1,20 +1,32 @@ add_custom_target(example_grouped_gemm_xdl) - add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32) + add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) -add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) -add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16) + add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16) + +add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16) +add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16) + +add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16) + +add_example_executable(example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16) + +add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8) -add_dependencies(example_grouped_gemm_xdl - example_grouped_gemm_xdl_fp32 - example_grouped_gemm_xdl_fp16 - example_grouped_gemm_xdl_bfp16 - example_grouped_gemm_xdl_int8 - example_grouped_gemm_multiple_d_dl_fp16) +add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8) if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) + add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) + add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) endif() diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp index a5c51ceb0cfbc1a19215c8a2313b12a3b4d5ec51..3e1f7f089371b17fb02ebca53b27b0860c7b9c72 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp similarity index 98% rename from example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp rename to example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp index 05d572a1f532446bdc0ec43e59b1eeadfe4bd054..680cee1f814db7cef177e87e4d9c03ba745b8e26 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a193fc39ba637cbd41df4743e0157ca40c38407b --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp @@ -0,0 +1,353 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Row; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; + +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 128, 16, 128, 32, 8, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int k_batch = 1; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> d0_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d0_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, d0_tensors_device, + c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d0_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + d0_tensors.push_back(Tensor( + f_host_tensor_descriptor(problem_size.Ms[i], problem_size.Ns[i], 0, ELayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " d_m_n: " << d0_tensors[i].mDesc + << " c_m_n: " << c_device_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(D0DataType) * d0_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + d0_tensors_device.emplace_back( + std::make_unique(sizeof(D0DataType) * problem_size.Ns[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + d0_tensors_device[i]->ToDevice(d0_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + 1, + problem_size.stride_Bs[i], + 1, + {0}}); + + grouped_gemm_kernel_args_.push_back( + {a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + std::array{d0_tensors_device[i]->GetDeviceBuffer()}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + std::array{0}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + std::vector p_Cs = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_kernel_args_dev(gemm.GetDeviceKernelArgSize(&argument)); + hip_check_error(hipMemcpy(gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + gemm.SetDeviceKernelArgs(argument, gemm_kernel_args_dev.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < problem_size.Ms[i]; ++m) + { + for(int n = 0; n < problem_size.Ns[i]; ++n) + { + cde_element_op( + c_host_tensors[i](m, n), c_host_tensors[i](m, n), d0_tensors[i](m, n)); + } + } + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ns.push_back(768); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ns[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..95b8526094dc0ed33862d493ba34961342df550f --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -0,0 +1,327 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + int k_batch = 1; + bool time_kernel = false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_Cs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + } + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + c_tensors_device[i]->SetZero(); + + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + 1, + problem_size.stride_Bs[i], + 1, + {}}); + + grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op); + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(128 + 128 * i); + problem_size.Ks.push_back(128 + 64 * i); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..84abe1d1db0f3de685c053ebe67dd7bcf3807140 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp @@ -0,0 +1,328 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F8; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + int k_batch = 1; + bool time_kernel = false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_Cs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + } + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + c_tensors_device[i]->SetZero(); + + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + 1, + problem_size.stride_Bs[i], + 1, + {}}); + + grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op); + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(128 + 128 * i); + problem_size.Ks.push_back(128 + 64 * i); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp index 3f78dafa8977b1b57592500138dd8f034b294efd..90a12bc1ddbf27e04a55de23e3fc6845868dac43 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp index fd93bb5f87d49a96da8973c3f862b8b48bdeb180..28b0fcd0cea6d6aed46ffeabc35c4c5ebc72b864 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp index faf41bbf0bba8e06fee9982a2f158706401ca45d..60c4a71a35146c31d9b79703faf5d3422942d817 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp index 7cb09778c521c9dfaed38c39ee5cb7fcff7bbc26..0c96ef56d3c144f19dcb4c7101a8d481d8960ee6 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9f8f6cb1e49d7a84f8a1bb596e568131874cf63f --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(256 + 256 * i); + problem_size.Ns.push_back(128 + 128 * i); + problem_size.Ks.push_back(128 + 64 * i); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 324e1772805ac2f1b9b84872d3b8baec3c39c6b5..320870e0de7cab9dee43cab6542f13ca5341a90b 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -147,6 +147,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co #else a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + c_tensors_device[i]->SetZero(); #endif p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); diff --git a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt index 226656a736637beea41758bf59dd1bd27ed8189f..5955e1d6cb4589d5b71b7bd6d66f26a894423584 100644 --- a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt +++ b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt @@ -1,40 +1,48 @@ -add_custom_target(example_gemm_reduce_xdl) -add_custom_target(example_gemm_reduce_xdl_max) -add_custom_target(example_gemm_reduce_xdl_mean_meansquare) -add_custom_target(example_gemm_add_add_mean_meansquare_xdl) - -add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) -add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) -add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) -add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) - -add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) - -add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) -add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) -add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) -add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) - -add_dependencies(example_gemm_reduce_xdl_max - example_gemm_max_xdl_bf16 - example_gemm_max_xdl_fp16 - example_gemm_max_xdl_fp32 - example_gemm_max_xdl_int8) - -add_dependencies(example_gemm_reduce_xdl_mean_meansquare - example_gemm_mean_meansquare_xdl_fp16 - example_gemm_mean_meansquare_xdl_fp32 - example_gemm_mean_meansquare_xdl_bf16 - example_gemm_add_addsquare_xdl_int8) - -add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) - -add_dependencies(example_gemm_reduce_xdl - example_gemm_reduce_xdl_mean_meansquare - example_gemm_reduce_xdl_max - example_gemm_add_add_mean_meansquare_xdl) - -if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) -endif() +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_gemm_reduce_xdl) + add_custom_target(example_gemm_reduce_xdl_max) + add_custom_target(example_gemm_reduce_xdl_mean_meansquare) + add_custom_target(example_gemm_add_add_mean_meansquare_xdl) + + add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) + + add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) + add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) + + add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) + add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) + + add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) + + add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) + add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) + + add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) + + add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) + add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) + + add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) + + add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) + add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) + + add_example_dependencies(example_gemm_reduce_xdl + example_gemm_reduce_xdl_mean_meansquare + example_gemm_reduce_xdl_max + example_gemm_add_add_mean_meansquare_xdl) + + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) + endif() + set(target 1) + endif() +endforeach() diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp index eb3832a668ff9829ee76e87e6ea8988b40e6980b..2f6533d4481a2f339417a19f0f290f7fd306ee9c 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp index e1248002f751abce96072ee64283ca00022a8713..b28e7f85d3137915a2fe438b16aef5447b2d4d24 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_addsquare_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp index c2feffeb8953df030ad2e30cb942bec529b2cac1..b30ce2c48ad32d99dc5c8a4aaf7a3ee9ea774baf 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp index 363390add3e99c96ef3891bec995cda5f8c3e24b..31e2efd6f635e30e569d97f73b4a5967b4909b82 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp index de6b7eb480b30d4347018da94ec518ff29fdbeec..d3c7c1d99c06a222bcefae42d94cdbb6fc8e1239 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp index 9666fc6622cdf3ab7cbab7071ebffd69d18b7668..9a4a6bc6e11a657b9d1eeaa35dd92552e397130d 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp index 00e0b767a45dd979868b81635a79600d4695c779..1a8457a8bf87cdc82e687246014d2cba73065b8f 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp index 652c0e6ea6d2bbe3321de2deefb2c186d7252023..5c2706c79ace02e72f2956270c322c4b00782ebf 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp index 7eee24fed83988698e0e6d46f830f3a9dfdb0257..c119e243702a27df9bba2d6cf3f791f523c98d7d 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp index c250b996928dd406bde9253ccbf40e6cd9d77347..0f5e588383abf30064d27e452131ceef1a7ad828 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp index 62992de59765d3e05d935ac69894c4550a7ef472..1bea1bcf3e0fc1494e07952020db175a1d91bfbf 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/17_convnd_bwd_data/CMakeLists.txt b/example/17_convnd_bwd_data/CMakeLists.txt index fa4e65d965f16c7049a21dba93e881fa04cd5319..7c6d10d8a0d0375dd5bb8ba498b81a08d91bd7cd 100644 --- a/example/17_convnd_bwd_data/CMakeLists.txt +++ b/example/17_convnd_bwd_data/CMakeLists.txt @@ -1,5 +1,16 @@ -add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) -target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) + if(result EQUAL 0) + target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) + endif() + set(target 1) + endif() +endforeach() add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) -target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) +if(result EQUAL 0) + target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) +endif() diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp index 26fa9e9821fc050c9ab69fe1a170b4c72a98dcd3..4a9d16c5c303e43989c1f32e51c2cbce6f279e5d 100644 --- a/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp +++ b/example/17_convnd_bwd_data/convnd_bwd_data_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -81,32 +81,33 @@ int run_conv_bwd_data(bool do_verification, in_device_buf.SetZero(); // do GEMM - auto conv = DeviceConvNdBwdDataInstance{}; - auto invoker = conv.MakeInvoker(); - auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - conv_param.N_, - conv_param.K_, - conv_param.C_, - conv_param.input_spatial_lengths_, - conv_param.filter_spatial_lengths_, - conv_param.GetOutputSpatialLengths(), - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - in_element_op, - wei_element_op, - out_element_op); - - if(!conv.IsSupportedArgument(argument)) + auto conv = DeviceConvNdBwdDataInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = + conv.MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param.N_, + conv_param.K_, + conv_param.C_, + conv_param.input_spatial_lengths_, + conv_param.filter_spatial_lengths_, + conv_param.GetOutputSpatialLengths(), + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument.get())) { std::cout << "Not support,please check parameters or device"; return 0; } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + float ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = conv_param.GetFlops(); std::size_t num_btype = conv_param.GetByte(); diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp b/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp index f0896e977144867421b857afd0a58a239716712b..6b84eaba471a966c89b62d9f3be32d9b0e0834d6 100644 --- a/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp +++ b/example/17_convnd_bwd_data/convnd_bwd_data_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_bwd_data_common.hpp" diff --git a/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp b/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp index c4f2c1f02bb9010ab3f5d97166819732e911b68a..c9989c60ac2c74e6821bdf6dd40e703998accf01 100644 --- a/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp +++ b/example/17_convnd_bwd_data/convnd_bwd_data_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_bwd_data_common.hpp" diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt index 99fc0043d2803494522b446cabf4a3f98ca12a5a..94ed129dc03664043b1a0295f9f1dcfb38a77002 100644 --- a/example/18_batched_gemm_reduce/CMakeLists.txt +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -1,2 +1,8 @@ -add_example_executable(example_batched_gemm_reduce_xdl_fp16 batched_gemm_reduce_xdl_fp16.cpp) - +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_batched_gemm_reduce_xdl_fp16 batched_gemm_reduce_xdl_fp16.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index c2e3602a7bb22e3f31c66bb6045171229584fa14..e363dc5c12dd84a36a67e969b06d29a179679f94 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp index bee5dea546f6da3cc23aa539583915c9b8a3b21e..24c8d82f674d97afae8cb1b3dc0274c1c13daf35 100644 --- a/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp +++ b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp index 6fc63b899e5c6abc91cad26c92c7d7733f0e8f5a..3c04c561403d0ecfb0819aa569bdf24511053442 100644 --- a/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp +++ b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/19_binary_elementwise/elementwise_add_1d.cpp b/example/19_binary_elementwise/elementwise_add_1d.cpp index a5a6bc0a8bedee0a16906381c91a9c93b070e572..1ac09641a1e20a6be8bf44f4f1ff0f3be045dedb 100644 --- a/example/19_binary_elementwise/elementwise_add_1d.cpp +++ b/example/19_binary_elementwise/elementwise_add_1d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/19_binary_elementwise/elementwise_add_4d.cpp b/example/19_binary_elementwise/elementwise_add_4d.cpp index cc209b12e3df12413beabfd11961671f4e303b41..e571aa8468008a5f4e7097eee74763a583faf9af 100644 --- a/example/19_binary_elementwise/elementwise_add_4d.cpp +++ b/example/19_binary_elementwise/elementwise_add_4d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index cbe4f5f4869d77c7b176980193cd03403c3d5c1b..c28fca6fa296273fb2cbd76ab9b9c58a2832e050 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -1,14 +1,29 @@ -add_custom_target(example_grouped_conv_bwd_weight) +list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_weight) + add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) -add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) -add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) + add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) + add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) + add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) + add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) + set(target 1) + endif() -add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16 - example_grouped_conv_bwd_weight_xdl_bf16) + if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_weight) + add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16) + set(target 1) + endif() +endforeach() add_custom_target(example_grouped_conv_bwd_weight_dl) add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) - -add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) +add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) diff --git a/example/20_grouped_conv_bwd_weight/common.hpp b/example/20_grouped_conv_bwd_weight/common.hpp index 3f4818d2e3336799c3dfb5cca4fb14d9428ab58a..1d2206ff72a7f64b23c79bb4874ace852a1c16c3 100644 --- a/example/20_grouped_conv_bwd_weight/common.hpp +++ b/example/20_grouped_conv_bwd_weight/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -23,6 +23,12 @@ using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +#ifdef CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif template using S = ck::Sequence; @@ -40,25 +46,21 @@ struct CommonLayoutSetting using OutputLayout = OutputLay; }; -template -struct CommonLayoutSettingSelector; - namespace ctl = ck::tensor_layout::convolution; - -template <> -struct CommonLayoutSettingSelector<1> final : CommonLayoutSetting -{ -}; - -template <> -struct CommonLayoutSettingSelector<2> final - : CommonLayoutSetting -{ -}; - -template <> -struct CommonLayoutSettingSelector<3> final - : CommonLayoutSetting +template +struct CommonLayoutSettingSelector + : CommonLayoutSetting>, + ck::tuple_element_t>, + ck::tuple_element_t>> { }; @@ -78,10 +80,10 @@ struct ExecutionConfig final bool time_kernel = false; }; -#define DefaultConvParam \ - ck::utils::conv::ConvParam \ - { \ - 2, 4, 1, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, { 1, 1 } \ +#define DefaultConvParam \ + ck::utils::conv::ConvParam \ + { \ + 3, 4, 1, 128, 256, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, { 1, 1, 1 } \ } inline void print_help_msg() diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp index 375c309e1c93d30a6d7167d524613c9386510551..44ba3d8e02498ce72c6f0f4f8e9bdcaaf9ef1760 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp @@ -3,7 +3,7 @@ #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp" using InDataType = F16; using WeiDataType = F16; @@ -15,45 +15,84 @@ using WeiElementOp = PassThrough; using OutElementOp = PassThrough; template -using DeviceConvBwdWeightInstance = - ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl< - NDimSpatial, // NDimSpatial - InDataType, // InDataType - WeiDataType, // WeiDataType - OutDataType, // OutDataType - AccDataType, // AccDataType - InElementOp, // InElementwiseOperation - WeiElementOp, // WeiElementwiseOperation - OutElementOp, // OutElementwiseOperation - ConvBwdWeightDefault, // ConvBackwardWeightSpecialization - 256, // BlockSize - 128, // MPerBlock - 128, // NPerBlock - 16, // K0PerBlock - 2, // K1 - 4, // M1PerThread - 4, // N1PerThread - 1, // KPerThread - S<8, 2>, // M1N1ThreadClusterM1Xs - S<8, 2>, // M1N1ThreadClusterN1Xs - S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 - S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 - S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder - S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder - S<1, 1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 - S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder - S<1, 1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 - S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 - S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 - S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder - S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder - S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 - S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder - S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 - S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder - 5, // CThreadTransferSrcDstVectorDim - 4>; // CThreadTransferDstScalarPerVector +using DeviceConvBwdWeightInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl< + NDimSpatial, // NDimSpatial + ck::tuple_element_t>, // InLayout + ck::tuple_element_t>, // WeiLayout + ck::tuple_element_t>, // OutLayout + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 2, // K1 + 4, // M1PerThread + 4, // N1PerThread + 1, // KPerThread + S<8, 2>, // M1N1ThreadClusterM1Xs + S<8, 2>, // M1N1ThreadClusterN1Xs + S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 + S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 + S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder + S<1, 1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 + S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 + S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 + S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 + S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder + S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder + S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + 4>; // CThreadTransferDstScalarPerVector + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; #include "run_grouped_conv_bwd_weight_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c2950b17e802b9418c1614eb8c8c34168929c3b --- /dev/null +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_wmma_fp16.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp" + +using InDataType = F16; +using WeiDataType = F16; +using OutDataType = F16; +using AccDataType = F32; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +template +using DeviceConvBwdWeightInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< + NDimSpatial, + ck::tensor_layout::convolution::GNDHWC, + ck::tensor_layout::convolution::GKZYXC, + ck::tensor_layout::convolution::GNDHWK, + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 16, // MPerWMMA + 16, // NPerWMMA + 4, // MRepeat + 2, // NRepeat + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // ABlockTransferSrcAccessOrder + 1, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 4, + 2, + S<1, 32, 1, 8>, + 1>; + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp index aed6d22b023aafc35e60fac29b6375a8602b87d5..80b97249304f1cb8403246785335d4c5c98784b1 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" using InDataType = BF16; // bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory @@ -17,8 +17,20 @@ using OutElementOp = PassThrough; template using DeviceConvBwdWeightInstance = - ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< - NDimSpatial, // NDimSpatial + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, InDataType, // InDataType WeiDataType, // WeiDataType OutDataType, // OutDataType @@ -55,6 +67,34 @@ using DeviceConvBwdWeightInstance = S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + #include "run_grouped_conv_bwd_weight_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp index 4a2a6195d9574c26e3526d570bfddff067bd3e37..5c1c97d615df22a41a8331c800e499a1b433b047 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp @@ -1,9 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" using InDataType = F16; using WeiDataType = F16; @@ -16,8 +16,20 @@ using OutElementOp = PassThrough; template using DeviceConvBwdWeightInstance = - ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle< - NDimSpatial, // NDimSpatial + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, InDataType, // InDataType WeiDataType, // WeiDataType OutDataType, // OutDataType @@ -54,6 +66,34 @@ using DeviceConvBwdWeightInstance = S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + #include "run_grouped_conv_bwd_weight_example.inc" -int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0d0d58eb2a3f227dbfbc9ad077f29d4f4102e677 --- /dev/null +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" + +using InDataType = F16; +using WeiDataType = F16; +using OutDataType = F16; +using AccDataType = F32; +using ComputeTypeA = BF8; +using ComputeTypeB = F8; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = PassThrough; + +template +using DeviceConvBwdWeightInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 1, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 1, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 2, // CBlockTransferScalarPerVector_NWaveNPerXdl + ComputeTypeA, // ComputeTypeA + ComputeTypeB>; // ComputeTypeB + +template +using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; + +#include "run_grouped_conv_bwd_weight_example.inc" + +int main(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return 1; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param); + case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param); + case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param); + default: break; + } + + return 1; +} diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index 7891812375f060768558556a7025bbe38458c379..73d15f3ea8c71a138a26bf9cf9aa58cd8fcf1e1c 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -1,31 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -template -using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight; +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. template bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_param) { - ck::index_t split_k; - // Set split_k = 2 for xdl op, split_k = 1 for dl - // Dl op doesn't support split_k > 1 - // TODO: Add Dl op split_k > 1 support - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")) - { - split_k = 2; - } - else - { - split_k = 1; - } + // Dl and WMMA ops don't support split_k > 1 + constexpr ck::index_t split_k = 1; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< @@ -56,8 +37,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - out.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + in.GenerateTensorValue(GeneratorTensor_3{0.0, 0.2}); + out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}); } DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); @@ -70,9 +51,12 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, // init to 0 wei_device_buf.SetZero(); - std::array input_spatial_lengths{}; - std::array filter_spatial_lengths{}; - std::array output_spatial_lengths{}; + std::array input_lengths{}; + std::array input_strides{}; + std::array filter_lengths{}; + std::array weights_strides{}; + std::array output_lengths{}; + std::array output_strides{}; std::array conv_filter_strides{}; std::array conv_filter_dilations{}; std::array input_left_pads{}; @@ -80,9 +64,12 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; - range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths)); - range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths)); - range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths)); + range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths)); + range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); + range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths)); + range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides)); + range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths)); + range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); range_copy(conv_param.input_left_pads_, begin(input_left_pads)); @@ -94,13 +81,12 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), - conv_param.G_, - conv_param.N_, - conv_param.K_, - conv_param.C_, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -118,18 +104,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, return true; } - float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - - std::size_t flop = conv_param.GetFlops(); - std::size_t num_btype = conv_param.GetByte(); - - float tflops = static_cast(flop) / 1.E9 / avg_time; - - float gb_per_sec = num_btype / 1.E6 / avg_time; - - std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl - << "DeviceOp: " << conv.GetTypeString() << std::endl; + invoker.Run(argument, StreamConfig{nullptr, false}); if(config.do_verification) { @@ -153,25 +128,18 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, return ck::utils::check_err(wei_device_result.mData, wei_host_result.mData); } - return true; -} + float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); -bool run_grouped_conv_bwd_weight_example(int argc, char* argv[]) -{ - ExecutionConfig config; - ck::utils::conv::ConvParam conv_param = DefaultConvParam; + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); - if(!parse_cmd_args(argc, argv, config, conv_param)) - { - return false; - } + float tflops = static_cast(flop) / 1.E9 / avg_time; - switch(conv_param.num_dim_spatial_) - { - case 1: return run_grouped_conv_bwd_weight<1>(config, conv_param); - case 2: return run_grouped_conv_bwd_weight<2>(config, conv_param); - case 3: return run_grouped_conv_bwd_weight<3>(config, conv_param); - } + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cerr << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl + << "DeviceOp: " << conv.GetTypeString() << std::endl; - return false; + return true; } diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt index 2eb7052e1eea21ea652f088813618b15256fca10..e231bc619b781c19f447b7fa7442f2beb58331cb 100644 --- a/example/21_gemm_layernorm/CMakeLists.txt +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -1,4 +1,12 @@ -add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) -add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) -add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) -add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) + add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) + add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) + add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) + set(target 1) + endif() +endforeach() + diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp index 192fe87b626ff721ca106b9c633c109483f5772a..96d04dcb37798fb6db3334bd2d802e435fdc93fd 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp index 3f01e6947728ecbef57734190b3896ff53899b02..6a92e9a2f50c8a5f1b80228307ad946188ea3c75 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -114,12 +114,15 @@ void host_gemm_layernorm(Tensor& h_m_n, BetaDataType, HDataType, AccDataType, + AccDataType, HElementOp, 2, 1>; Tensor e_m_n(HostTensorDescriptor{M, N}); Tensor c_m_n(HostTensorDescriptor{M, N}); + Tensor save_mean({M}); + Tensor save_inv_std({M}); auto ref_gemm = ReferenceGemm{}; auto ref_gemm_invoker = ref_gemm.MakeInvoker(); @@ -145,7 +148,7 @@ void host_gemm_layernorm(Tensor& h_m_n, auto ref_layernorm_invoker = ref_layernorm.MakeInvoker(); auto ref_layernorm_argument = ref_layernorm.MakeArgument( - e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon); + e_m_n, gamma_n, beta_n, h_m_n, save_mean, save_inv_std, h_element_op, {M, N}, {1}, epsilon); ref_layernorm_invoker.Run(ref_layernorm_argument); } diff --git a/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp b/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp index 4da6da65f7ab21b6ffae367747ccf5f0c71232dc..bd1d6932aceb72890dcb0f7c35004cc0dccd93af 100644 --- a/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp index e7d857c4a0fa53a262243b759ddf9b33c691e26c..90d80f9f034b391f75c498f3a34232edf64f5260 100644 --- a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/22_cgemm/CMakeLists.txt b/example/22_cgemm/CMakeLists.txt index 156456115611ca12744e7773c199138733b88cd7..44585b11d04e762f58cc20c1816200857eef0fa6 100644 --- a/example/22_cgemm/CMakeLists.txt +++ b/example/22_cgemm/CMakeLists.txt @@ -1,17 +1,18 @@ add_custom_target(example_cgemm_xdl) add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16) + add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16) + add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp) -add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32) -add_dependencies(example_cgemm_xdl - example_cgemm_xdl_bf16 - example_cgemm_xdl_fp16 - example_cgemm_xdl_fp32 - example_cgemm_xdl_int8) +add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) +add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int8) if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp) - add_dependencies(example_cgemm_xdl example_cgemm_xdl_int4) + add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp) + add_example_dependencies(example_cgemm_xdl example_cgemm_xdl_int4) endif() diff --git a/example/22_cgemm/cgemm_xdl_bf16.cpp b/example/22_cgemm/cgemm_xdl_bf16.cpp index 92ed90ce4ab3ec09da2e053ceebe4d4ed6bcf36d..fa4482a984f20d203bd8bee68614c8e71ecbf83c 100644 --- a/example/22_cgemm/cgemm_xdl_bf16.cpp +++ b/example/22_cgemm/cgemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/22_cgemm/cgemm_xdl_common.hpp b/example/22_cgemm/cgemm_xdl_common.hpp index 6aa06b7c32cb476b61a88235e13a12a3d1a15db5..26137a7c2e50d0d5010639592c0a614e2eca607b 100644 --- a/example/22_cgemm/cgemm_xdl_common.hpp +++ b/example/22_cgemm/cgemm_xdl_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/22_cgemm/cgemm_xdl_fp16.cpp b/example/22_cgemm/cgemm_xdl_fp16.cpp index 11373736ee8b37efb5f4082253d46f02074f011f..89a581e865a56b231711352f9733403ed2945aea 100644 --- a/example/22_cgemm/cgemm_xdl_fp16.cpp +++ b/example/22_cgemm/cgemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/22_cgemm/cgemm_xdl_fp32.cpp b/example/22_cgemm/cgemm_xdl_fp32.cpp index 0f45c18c4818726f179adda70dc12b5ea7c45b9d..cf9659959990823f10ef0cb30fa0eef928c32adc 100644 --- a/example/22_cgemm/cgemm_xdl_fp32.cpp +++ b/example/22_cgemm/cgemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/22_cgemm/cgemm_xdl_int4.cpp b/example/22_cgemm/cgemm_xdl_int4.cpp index c26a83baafd375b477845fbdd793c3918d7b7dc5..f69cc2b3cc4f98e94ddbf0b37e8ac7324515a8d6 100644 --- a/example/22_cgemm/cgemm_xdl_int4.cpp +++ b/example/22_cgemm/cgemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/22_cgemm/cgemm_xdl_int8.cpp b/example/22_cgemm/cgemm_xdl_int8.cpp index 2f24189861d8468848d28b9c0480da7c3f2c4fdd..c4835b853ee75496f963a8e14d7a0190eb4cb1cf 100644 --- a/example/22_cgemm/cgemm_xdl_int8.cpp +++ b/example/22_cgemm/cgemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/example/23_softmax/softmax_blockwise.cpp b/example/23_softmax/softmax_blockwise.cpp index 41afd72f5ac9ddc3dc030a914e1a26da60c27356..d09e434bcfbb262a83c5167429d41d1ca54391ef 100644 --- a/example/23_softmax/softmax_blockwise.cpp +++ b/example/23_softmax/softmax_blockwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/24_batched_gemm/CMakeLists.txt b/example/24_batched_gemm/CMakeLists.txt index 7962576e875dd6b7addd659dd9b9f3b6c1504a41..4cb45be7c90182928e6444c3d923c65b28add384 100644 --- a/example/24_batched_gemm/CMakeLists.txt +++ b/example/24_batched_gemm/CMakeLists.txt @@ -1,17 +1,18 @@ add_custom_target(example_batched_gemm_xdl) add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32) + add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) -add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp) -add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16) -add_dependencies(example_batched_gemm_xdl - example_batched_gemm_xdl_fp32 - example_batched_gemm_xdl_fp16 - example_batched_gemm_xdl_bfp16 - example_batched_gemm_xdl_int8) +add_example_executable(example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16) + +add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) +add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8) if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) + add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) + add_example_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) endif() diff --git a/example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp b/example/24_batched_gemm/batched_gemm_xdl_bf16.cpp similarity index 100% rename from example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp rename to example/24_batched_gemm/batched_gemm_xdl_bf16.cpp diff --git a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp index c934d35019602e634da6f8b6d49c1c3b133c6800..420a7cf74f3186ac62d5dc37346a202178d8d273 100644 --- a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp +++ b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp index 98835f98fa6ecea9c38f1155f5dd03d0787fa0d2..9d606db205dde86d544971b4b7fc4830ba73c568 100644 --- a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp +++ b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt index c58751f0dd38e8fcc0df0da428777e7f3c7645b1..167a9f147d85a1dd034702dc565118425a8cf5af 100644 --- a/example/26_contraction/CMakeLists.txt +++ b/example/26_contraction/CMakeLists.txt @@ -1,5 +1,4 @@ add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) - add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp index ea105e4ff2bb9a33830fd1182007d658c52ba7e0..78522160c85a1024ec510197e5ab1069e6f077cf 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -16,6 +16,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" template using S = ck::Sequence; @@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device:: using DeviceOpInstance = DeviceOpInstanceKKNN; -// hardcoded for NumDimM == NumDimN == NumDimK == 2 -template = false> -struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_ms_ks_{a_ms_ks}, - b_ns_ks_{b_ns_ks}, - e_ms_ns_{e_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_ms_ks_; - const Tensor& b_ns_ks_; - Tensor& e_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_M2_N2_K2::Argument; - - float Run(const Argument& arg) - { - auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { - const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; - const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - for(int k1 = 0; k1 < K1; ++k1) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); - arg.b_element_op_( - v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); - - v_acc += v_a * v_b; - } - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_ms_ns_(m0, m1, n0, n1) = v_c; - }; - - make_ParallelTensorFunctor(f_ms_ns, - arg.e_ms_ns_.mDesc.GetLengths()[0], - arg.e_ms_ns_.mDesc.GetLengths()[1], - arg.e_ms_ns_.mDesc.GetLengths()[2], - arg.e_ms_ns_.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_M2_N2_K2" - << std::endl; - // clang-format on - - return str.str(); - } -}; - int main(int argc, char* argv[]) { bool do_verification = true; @@ -385,22 +251,22 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; - - auto ref_gemm = ReferenceOpInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); ref_invoker.Run(ref_argument); diff --git a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp index 9a000377bba2b92fe7509fc07821faae3634b5c0..6cceed5bc11ed8eb79568a0397bbf7e1b93fd33d 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -16,6 +16,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" template using S = ck::Sequence; @@ -74,141 +75,6 @@ using DeviceOpInstanceMNNN = ck::tensor_operation::device:: using DeviceOpInstance = DeviceOpInstanceKKNN; -// hardcoded for NumDimM == NumDimN == NumDimK == 2 -template = false> -struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_ms_ks_{a_ms_ks}, - b_ns_ks_{b_ns_ks}, - e_ms_ns_{e_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_ms_ks_; - const Tensor& b_ns_ks_; - Tensor& e_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_M2_N2_K2::Argument; - - float Run(const Argument& arg) - { - auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { - const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; - const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - for(int k1 = 0; k1 < K1; ++k1) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); - arg.b_element_op_( - v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); - - v_acc += v_a * v_b; - } - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_ms_ns_(m0, m1, n0, n1) = v_c; - }; - - make_ParallelTensorFunctor(f_ms_ns, - arg.e_ms_ns_.mDesc.GetLengths()[0], - arg.e_ms_ns_.mDesc.GetLengths()[1], - arg.e_ms_ns_.mDesc.GetLengths()[2], - arg.e_ms_ns_.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_M2_N2_K2" - << std::endl; - // clang-format on - - return str.str(); - } -}; - int main(int argc, char* argv[]) { bool do_verification = true; @@ -385,22 +251,22 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; - - auto ref_gemm = ReferenceOpInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); ref_invoker.Run(ref_argument); diff --git a/example/26_contraction/contraction_scale_xdl_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp32.cpp index 26f176b0591d02a821fa783542e9538791e4cbcd..1574f5d18fb7f5f193cd4adb1c1a88616d7b1437 100644 --- a/example/26_contraction/contraction_scale_xdl_fp32.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -16,6 +16,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" template using S = ck::Sequence; @@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device:: using DeviceOpInstance = DeviceOpInstanceKKN; -// hardcoded for NumDimM == NumDimN == NumDimK == 2 -template = false> -struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_ms_ks_{a_ms_ks}, - b_ns_ks_{b_ns_ks}, - e_ms_ns_{e_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_ms_ks_; - const Tensor& b_ns_ks_; - Tensor& e_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_M2_N2_K2::Argument; - - float Run(const Argument& arg) - { - auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { - const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; - const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - for(int k1 = 0; k1 < K1; ++k1) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); - arg.b_element_op_( - v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); - - v_acc += v_a * v_b; - } - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_ms_ns_(m0, m1, n0, n1) = v_c; - }; - - make_ParallelTensorFunctor(f_ms_ns, - arg.e_ms_ns_.mDesc.GetLengths()[0], - arg.e_ms_ns_.mDesc.GetLengths()[1], - arg.e_ms_ns_.mDesc.GetLengths()[2], - arg.e_ms_ns_.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_M2_N2_K2" - << std::endl; - // clang-format on - - return str.str(); - } -}; - int main(int argc, char* argv[]) { bool do_verification = true; @@ -368,22 +234,23 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; - - auto ref_gemm = ReferenceOpInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor empty_tensor(std::vector{}, std::vector{}); + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); ref_invoker.Run(ref_argument); diff --git a/example/26_contraction/contraction_scale_xdl_fp64.cpp b/example/26_contraction/contraction_scale_xdl_fp64.cpp index 38ed60266de9e5a8322cf0fffa5fcb9eb33cbc84..3dacc708877d838dd5fb0113b5d560885b0551ae 100644 --- a/example/26_contraction/contraction_scale_xdl_fp64.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -16,6 +16,7 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" template using S = ck::Sequence; @@ -73,141 +74,6 @@ using DeviceOpInstanceMNN = ck::tensor_operation::device:: using DeviceOpInstance = DeviceOpInstanceKKN; -// hardcoded for NumDimM == NumDimN == NumDimK == 2 -template = false> -struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator -{ - // Argument - struct Argument : public ck::tensor_operation::device::BaseArgument - { - Argument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : a_ms_ks_{a_ms_ks}, - b_ns_ks_{b_ns_ks}, - e_ms_ns_{e_ms_ns}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - } - - const Tensor& a_ms_ks_; - const Tensor& b_ns_ks_; - Tensor& e_ms_ns_; - - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public ck::tensor_operation::device::BaseInvoker - { - using Argument = ReferenceContraction_M2_N2_K2::Argument; - - float Run(const Argument& arg) - { - auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { - const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; - const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; - - AccDataType v_acc = 0; - - for(int k0 = 0; k0 < K0; ++k0) - { - for(int k1 = 0; k1 < K1; ++k1) - { - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_( - v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); - arg.b_element_op_( - v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); - - v_acc += v_a * v_b; - } - } - - AccDataType v_c; - - arg.cde_element_op_(v_c, v_acc); - - arg.e_ms_ns_(m0, m1, n0, n1) = v_c; - }; - - make_ParallelTensorFunctor(f_ms_ns, - arg.e_ms_ns_.mDesc.GetLengths()[0], - arg.e_ms_ns_.mDesc.GetLengths()[1], - arg.e_ms_ns_.mDesc.GetLengths()[2], - arg.e_ms_ns_.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); - - return 0; - } - - float Run(const ck::tensor_operation::device::BaseArgument* p_arg, - const StreamConfig& /* stream_config */ = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override - { - return true; - } - - static auto MakeArgument(const Tensor& a_ms_ks, - const Tensor& b_ns_ks, - Tensor& e_ms_ns, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - virtual std::unique_ptr MakeInvokerPointer() - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "ReferenceContraction_M2_N2_K2" - << std::endl; - // clang-format on - - return str.str(); - } -}; - int main(int argc, char* argv[]) { bool do_verification = true; @@ -368,22 +234,23 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; - - auto ref_gemm = ReferenceOpInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor empty_tensor(std::vector{}, std::vector{}); + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); ref_invoker.Run(ref_argument); diff --git a/example/27_layernorm/CMakeLists.txt b/example/27_layernorm/CMakeLists.txt index d96deae45e49acdc8b86f9973417d6a77e1e7e31..94c23ce77499e0c110301a3d3ae82b0ec119a7c7 100644 --- a/example/27_layernorm/CMakeLists.txt +++ b/example/27_layernorm/CMakeLists.txt @@ -1 +1,2 @@ -add_example_executable(example_layernorm_blockwise layernorm_blockwise.cpp) +add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp) +add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp) diff --git a/example/27_layernorm/common.hpp b/example/27_layernorm/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..62a71713df84351fea902cdcf3275786b60f5a23 --- /dev/null +++ b/example/27_layernorm/common.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" diff --git a/example/27_layernorm/layernorm_blockwise.cpp b/example/27_layernorm/layernorm_blockwise.cpp deleted file mode 100644 index 7d91b69d04767e00a3cba38f60f43153bdb0d036..0000000000000000000000000000000000000000 --- a/example/27_layernorm/layernorm_blockwise.cpp +++ /dev/null @@ -1,139 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/utility/reduction_enums.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" -#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_common_util.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" - -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -constexpr int Rank = 2; -constexpr int NumReduceDim = 1; - -using DeviceInstance = - ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector - -int main() -{ - bool time_kernel = false; - - ck::index_t M = 1024; - ck::index_t N = 1024; - ck::index_t Stride = N; - - auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { - return HostTensorDescriptor({len}, {stride}); - }; - - auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) { - using namespace ck::literals; - - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - }; - - Tensor x(f_host_tensor_descriptor2d(M, N, Stride)); - Tensor gamma(f_host_tensor_descriptor1d(N, 1)); - Tensor beta(f_host_tensor_descriptor1d(N, 1)); - Tensor y(f_host_tensor_descriptor2d(M, N, Stride)); - - x.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); - DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); - DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); - DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); - - x_dev.ToDevice(x.mData.data()); - gamma_dev.ToDevice(gamma.mData.data()); - beta_dev.ToDevice(beta.mData.data()); - - auto device_instance = DeviceInstance{}; - auto argument_ptr = device_instance.MakeArgumentPointer( - {M, N}, - std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, - {0, 1}, - {0, 1}, - std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, - {1}, - 1e-4, - x_dev.GetDeviceBuffer(), - gamma_dev.GetDeviceBuffer(), - beta_dev.GetDeviceBuffer(), - y_dev.GetDeviceBuffer(), - nullptr, - nullptr, - PassThrough{}); - - if(!device_instance.IsSupportedArgument(argument_ptr.get())) - { - std::cout << "The runtime parameters are not supported" << std::endl; - return 1; - }; - - auto invoker_ptr = device_instance.MakeInvokerPointer(); - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - - bool pass = true; - { - Tensor host_y(f_host_tensor_descriptor2d(M, N, Stride)); - using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm; - - ReferenceInstance ref; - auto ref_argument = - ref.MakeArgument(x, gamma, beta, host_y, PassThrough{}, {M, N}, {1}, 1e-4); - auto ref_invoker = ref.MakeInvoker(); - ref_invoker.Run(ref_argument); - - y_dev.FromDevice(y.mData.data()); - pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results d1", 1e-3, 1e-3); - } - return (pass ? 0 : 1); -} diff --git a/example/27_layernorm/layernorm_fp16.cpp b/example/27_layernorm/layernorm_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..255452e7691562eeae739ae0b019366a707d6ee6 --- /dev/null +++ b/example/27_layernorm/layernorm_fp16.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationImpl; // SaveMeanInvStdScalarPerVector +#include "run_layernorm_example.inc" + +int main() { return run_groupnorm_example(); } diff --git a/example/27_layernorm/layernorm_splitk_fp16.cpp b/example/27_layernorm/layernorm_splitk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e2a85bddc5ccce08d4946734353f5d4be10edcae --- /dev/null +++ b/example/27_layernorm/layernorm_splitk_fp16.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +#define SAVE_MEAN_INV_STD + +constexpr int Rank = 2; +constexpr int NumReduceDim = 1; + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationSplitKImpl; // SaveMeanInvStdScalarPerVector + +#include "run_layernorm_example.inc" + +int main() { return run_groupnorm_example(); } diff --git a/example/27_layernorm/run_layernorm_example.inc b/example/27_layernorm/run_layernorm_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..399165c36eedbd56e6a0f4fc6e43c53fa6054774 --- /dev/null +++ b/example/27_layernorm/run_layernorm_example.inc @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +int run_groupnorm_example() +{ + bool time_kernel = false; + + ck::index_t M = 1024; + ck::index_t N = 1024; + + Tensor x({M, N}); + Tensor gamma({N}); + Tensor beta({N}); + Tensor y({M, N}); + Tensor save_mean({M}); + Tensor save_inv_std({M}); + + x.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); +#ifdef SAVE_MEAN_INV_STD + DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize()); + DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) * + save_inv_std.mDesc.GetElementSpaceSize()); +#endif + + x_dev.ToDevice(x.mData.data()); + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + auto device_instance = DeviceInstance{}; + auto argument_ptr = device_instance.MakeArgumentPointer( + {M, N}, + std::vector{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()}, + {0, 1}, + {0, 1}, + std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + {1}, + 1e-4, + x_dev.GetDeviceBuffer(), + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + y_dev.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.GetDeviceBuffer(), + save_inv_std_dev.GetDeviceBuffer(), +#else + nullptr, + nullptr, +#endif + PassThrough{}); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cout << "The runtime parameters are not supported" << std::endl; + return 1; + }; + + size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor host_y({M, N}); + Tensor host_save_mean({M}); + Tensor host_save_inv_std({M}); + + using ReferenceInstance = + ck::tensor_operation::host::ReferenceLayernorm; + + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument(x, + gamma, + beta, + host_y, + host_save_mean, + host_save_inv_std, + PassThrough{}, + {M, N}, + {1}, + 1e-4); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + y_dev.FromDevice(y.mData.data()); + pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results (y)", 1e-3, 1e-3); +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.FromDevice(save_mean.mData.data()); + save_inv_std_dev.FromDevice(save_inv_std.mData.data()); + pass &= ck::utils::check_err( + save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3); + pass &= ck::utils::check_err( + save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3); +#endif + } + + return (pass ? 0 : 1); +} diff --git a/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp b/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp index f8e6501eadaee4e5b4156309afebdf16e11f3e8a..24e9b1d9b7d569adf20caf1bd0f0cddafdbd256b 100644 --- a/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index 30ad38a566309dcfacd52f100099d8a3bf02f77f..62233e535151e6b647e938de04b6225fe487b6c4 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp index 25d815b9cdfdddc29c699a1bc51dd5191a5246eb..08158bfc250642e032d20fad940bb8853c71d446 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index 4b0ea4f1579694250bcd04f0c4374728eadbf1ff..3a8c2ef52fac50d10bf2d1167ef1d4819281d312 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -1,25 +1,40 @@ -add_custom_target(example_grouped_conv_fwd_multiple_d) +list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102) -add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) -add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) -add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) -add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list1 AND target EQUAL 0) + add_custom_target(example_grouped_conv_fwd_multiple_d) -add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) -add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) -add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) -add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) -if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) + add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) -endif() # USE_BITINT_EXTENSION_INT4 + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) -endif() + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) -add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) -add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) + endif() # USE_BITINT_EXTENSION_INT4 + + set(target 1) + endif() +endforeach() + +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list2 AND target EQUAL 0) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/30_grouped_conv_fwd_multiple_d/README.md b/example/30_grouped_conv_fwd_multiple_d/README.md index 739a0425a8c29419c3e31e433815c5230bd1a5cb..4718795c576c0f773fa92fadb54a041e884851c5 100644 --- a/example/30_grouped_conv_fwd_multiple_d/README.md +++ b/example/30_grouped_conv_fwd_multiple_d/README.md @@ -4,7 +4,7 @@ arg1: verification (0=no, 1=yes) arg2: initialization (0=no init, 1=integer value, 2=decimal value) arg3: time kernel (0=no, 1=yes) Following arguments (depending on number of spatial dims): - Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d) + Number of spatial dimensions (1=Conv1D, 2=Conv2D, 3=Conv3D) G, N, K, C, , (ie Y, X for 2D) , (ie Hi, Wi for 2D) diff --git a/example/30_grouped_conv_fwd_multiple_d/common.hpp b/example/30_grouped_conv_fwd_multiple_d/common.hpp index e7c6ed9b939abb0e8f593d54d5b45c499db5f2c0..e60ebee6e4fe2e08db432816b75bbbd4c388946c 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp index eb6975a6d812aef66f1caf493dfe82549a0e4ae9..ae769ff1d38306f4feed8d3cb00dde79b3f561eb 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp index 9d1d257a2889dcdaefb697faa6bcf6c172c55839..039d25029921491e7e67808e554b8cb3e6eb4745 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common_wmma.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..793324970e700748ac4b18e9fc429a24f70c007d --- /dev/null +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common_wmma.hpp" + +// kernel data types +using InKernelDataType = I8; +using WeiKernelDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I8; +using BiasKernelDataType = I8; +using ResidualKernelDataType = I8; +using OutKernelDataType = I8; + +// tensor data types +using InUserDataType = InKernelDataType; +using WeiUserDataType = WeiKernelDataType; +using OutUserDataType = OutKernelDataType; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + +#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp index ee300d073a28b44cf8177f64a7278d5e13ef3130..43c0d57dc2adcae69ecde327137eb204b71fa44d 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp index 5a9df0b1e880c8f3bc954c07cd5de3639644c4e1..40b4132b358d8259146c2987472dc7e9d6c71091 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp index c2906cc9dd1dbe22c8d420ce0929783fb0be0fa6..e05d384f26b40273a6bfe77222b805289a99e631 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp index 3d5a243e6b9282f9d362c82c0758776341df646f..5494563fdd568d1f51ec9ee9042d094096e8fc8a 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp index eaf680fa438a14ee402e3805c18692ffdcf78c7e..6bf2e8d963c1524bfec0680409580b47e5394bc5 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc index 4561156e0bd3a935c49751802e8c3037d8735228..eb242203eaa65ba7b8fca45c5f1e3e1b9d96c409 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. template struct LayoutSetting diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc index a6888649c0ed0adc6f2084c428b2c847cd952f77..360b2c8947b371b1502afb3af73c469f0d8c3636 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. template struct LayoutSetting diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc index d087c31af5def501635ef173769851e111dec3cb..58ed69182e0622751dc3fa2823fcb7cfa5b279a8 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. template using DeviceConvFwdInstance = diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt b/example/31_batched_gemm_gemm/CMakeLists.txt index ad40c96b418802fd69f337db3667e64ca1922639..93f16c945f22f4180d6a79477be3fa0aca130e5e 100644 --- a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -1,10 +1,18 @@ -add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) -add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) -add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) -if(NOT GPU_TARGETS MATCHES "gfx940") - add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) -endif() +list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) + +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list1 AND target EQUAL 0) + add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) + add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) + add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) + endif(USE_BITINT_EXTENSION_INT4) + set(target 1) + endif() +endforeach() -if(USE_BITINT_EXTENSION_INT4) -add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) -endif(USE_BITINT_EXTENSION_INT4) +if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") + add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) +endif() diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp index 74e0e07e62a99e0700213e80a066d7e15128e516..7605d9c4f8368e369e4c29167f90620984b5a6c8 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp index d5fadb8081e86a83784f6fad7e7405a7808b65b9..33ed04fb3068b1bf83eb128825ccad355c6de4f9 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp index 0dd4e0914f479eefae604ceba8115432ce1d3803..e0eb193ad0484c62ac6fa5695ac43aa17c509227 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp index 1fd93622a1b14a72d8159517a84b9f19d2959723..d166214c3376cd90afd11796a4cb85ea28421861 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp index 15d98abab7dfa0f1168fd411eeae481054053000..40f87d1f554c4ef2a73ceb32263c3499f733b40a 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o diff --git a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc index 7e5f1614bcf6c6133d88f9d8dad63ec4be47c9ac..f329146728dd8df3c38af3d2520ac91d43d12d00 100644 --- a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc +++ b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index 8d9aaec85a503b060254ef748db77f4a7f289dae..2a24abf094b09d40b7feef6e9a46a782dc7610d8 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -1,16 +1,23 @@ +add_custom_target(example_gemm_scale_softmax_gemm) + add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) -add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) + add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) -add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) + add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) + add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) + add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) + +add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) + +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) +add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) -add_custom_target(example_gemm_scale_softmax_gemm) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) -add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) -add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) -add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp index 0eb15653306f0179b1d132ddfcd130af6ba13fd6..1d1566d57561afaa84be2de34b11416b5da0571c 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp index 8f1db577c604418b7b1d546b5de251f5a20054a6..bae88d4b8e6f786219e8d68bd0a34f6036dd0d72 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp index 2ce91a8c6023314d61b354614c81ca74c5f95727..a098ce6675e05308fd9de223e807a0e14504a7f5 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp index 1fd2bf69306f5b35dd4ea2ae484c677c6b3c63fc..ce8caf758842225098baac92876ea6996321ff21 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp index f4a8589052f0f149902553884a58cc4c6c79d030..138db14963809cd49294d1cfc2171b038dd10c60 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp index e4a71b04313446f75f4cbc73456a66bdbb62c47e..57949242941d7f750d1b121712f1d5cca6cee4d5 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o @@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/utility/check_err.hpp" diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp index 38b5badc6e4f270e409ed98c492035c5bebe00c3..97caec6053bb7846dd53f37b26c3945a55812967 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o @@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/utility/check_err.hpp" diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc index 4e43dbdd8fc5a83486cd052ee3e789e05152861f..27602e2313f7aa197e88e1fabeb39245e2fdf5eb 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc index 0b876af952f65f89e9ae55bea344bfa6f61afd5d..fa76faea84e4551ddf8d0617c132dbe7a6045fb3 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc index ef2acf61f55fa7f108e174d4586c13ae3dc7419d..ea1e2734a684b61a363eb93ef0e2ff933f900ea5 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. int run(int argc, char* argv[]) { diff --git a/example/33_multiple_reduce/dual_reduce_common.hpp b/example/33_multiple_reduce/dual_reduce_common.hpp index 326606752b25394d504cb054cf3026173c66e970..cd21790be6548abd0f8097613ee650d4407fc400 100644 --- a/example/33_multiple_reduce/dual_reduce_common.hpp +++ b/example/33_multiple_reduce/dual_reduce_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/33_multiple_reduce/dual_reduce_multiblock.cpp b/example/33_multiple_reduce/dual_reduce_multiblock.cpp index 9360599ed9e8a7517bc244e2942d6949c1f71e61..198931749b19b8e8d91bf16c1a733801cb27aeba 100644 --- a/example/33_multiple_reduce/dual_reduce_multiblock.cpp +++ b/example/33_multiple_reduce/dual_reduce_multiblock.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/33_multiple_reduce/dual_reduce_threadwise.cpp b/example/33_multiple_reduce/dual_reduce_threadwise.cpp index 56255839e567fa0386ba09b80738f5d0d2abcc23..7609edad3527e8219c2ef6351381d6e53455a67b 100644 --- a/example/33_multiple_reduce/dual_reduce_threadwise.cpp +++ b/example/33_multiple_reduce/dual_reduce_threadwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/34_batchnorm/CMakeLists.txt b/example/34_batchnorm/CMakeLists.txt index d964f40d877ddb5dc6e03cf8d28897e244514d96..60824c5f4d7d742941721085be6fc22d638dffe9 100644 --- a/example/34_batchnorm/CMakeLists.txt +++ b/example/34_batchnorm/CMakeLists.txt @@ -1,3 +1,4 @@ add_example_executable(example_batchnorm_forward_training batchnorm_forward_training_nhwc.cpp) +add_example_executable(example_batchnorm_forward_training_obsolete batchnorm_forward_training_nhwc_obsolete.cpp) add_example_executable(example_batchnorm_forward_inferring batchnorm_forward_inferring_nhwc.cpp) add_example_executable(example_batchnorm_backward batchnorm_backward_nhwc.cpp) diff --git a/example/34_batchnorm/batchnorm_backward_nhwc.cpp b/example/34_batchnorm/batchnorm_backward_nhwc.cpp index a6ca9d150bd918966bae06a63ee4eb9da6be5f3f..3756310fd7dd47b111b9fde743cac844ae2c7669 100644 --- a/example/34_batchnorm/batchnorm_backward_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_backward_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/34_batchnorm/batchnorm_common.hpp b/example/34_batchnorm/batchnorm_common.hpp index bdbc8ea8b88f40de2f25a2ec8c5a74ab5e38fd74..a1b8d253bf061c02511945b98eac9343c0d1a236 100644 --- a/example/34_batchnorm/batchnorm_common.hpp +++ b/example/34_batchnorm/batchnorm_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp index dc2984851a02e0020c8c4cfa6d59e61bfe49c593..6a8002025a60b277c587190a95855cc262c33e17 100644 --- a/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp index da36d65a2954ee1ac3a94c617a219e6a7c2baf44..b27358fd9de4547dc8fa8532981f097713eb1e07 100644 --- a/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -414,7 +414,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification, (void)invoker_ptr_ref->Run(argument_ptr_ref.get()); y_dev.FromDevice(y.mData.data()); - pass = pass && ck::utils::check_err(y, y_ref); + pass = pass && ck::utils::check_err(y, y_ref, "Incorrect normalized output values"); if(updateMovingAverage) { @@ -424,8 +424,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification, resultRunningMean_dev.FromDevice(resultRunningMean.mData.data()); resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data()); - pass = pass && ck::utils::check_err(resultRunningMean, resultRunningMean_ref); - pass = pass && ck::utils::check_err(resultRunningVariance, resultRunningVariance_ref); + pass = pass && ck::utils::check_err(resultRunningMean, + resultRunningMean_ref, + "Incorrect running mean values"); + pass = pass && ck::utils::check_err(resultRunningVariance, + resultRunningVariance_ref, + "Incorrect running variance values"); }; if(saveMeanAndInvVariance) @@ -438,8 +442,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification, resultSaveMean_dev.FromDevice(resultSaveMean.mData.data()); resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data()); - pass = pass && ck::utils::check_err(resultSaveMean, resultSaveMean_ref); - pass = pass && ck::utils::check_err(resultSaveInvVariance, resultSaveInvVariance_ref); + pass = pass && ck::utils::check_err( + resultSaveMean, resultSaveMean_ref, "Incorrect saved mean values"); + pass = pass && ck::utils::check_err(resultSaveInvVariance, + resultSaveInvVariance_ref, + "Incorrect saved invvariance values"); }; }; diff --git a/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp b/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ffb9f4b58468699c1a34ff6aff101c2dd4adaf66 --- /dev/null +++ b/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp @@ -0,0 +1,598 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp" +#include "ck/library/utility/host_common_util.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class BatchNormFwdArg +{ + private: + int option_index = 0; + + public: + std::vector inOutLengths; + + bool do_verification = false; + + bool updateMovingAverage; + bool saveMeanAndInvVariance; + + int data_type = 0; + int init_method = 2; + bool time_kernel = false; + bool use_multiblock_welford = false; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension " + "lengths, must have 4 integers for nhwc" + << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the batch-normalization " + "result by " + "comparing with the host-based batch-normalization" + << std::endl; + std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl; + std::cout << "Arg2: 1/0 to indicate whether to update the moving average and variance " + "(0=no, 1=yes)" + << std::endl; + std::cout << "Arg3: 1/0 to indicate whether to save the calculated mean and invVariance " + "(0=no, 1=yes)" + << std::endl; + std::cout << "Arg4: init method used for bnScale and bnBias (0=no init, 1=single integer " + "value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl; + std::cout << "Arg6: use multi-block welford (0=n0, 1=yes)" << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:v:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inOutLengths = getTypeValuesFromString(optarg); + + if(inOutLengths.size() != 4) + throw std::runtime_error( + "NHWC tensor layout should have 4 length values specified!"); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 6 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + data_type = std::atoi(argv[optind++]); + updateMovingAverage = std::atoi(argv[optind++]); + saveMeanAndInvVariance = std::atoi(argv[optind++]); + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind++])); + use_multiblock_welford = static_cast(std::atoi(argv[optind])); + + if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) + return (-1); + + return (0); + }; +}; + +using namespace ck; + +template +bool bnorm_fwd_nhwc_test(bool do_verification, + int init_method, + bool time_kernel, + const std::vector inOutLengths, + bool updateMovingAverage, + bool saveMeanAndInvVariance, + double averageFactor, + double epsilon) +{ + // for NHWC BatchNorm calculation of mean and meansquare + constexpr int Rank = 4; + constexpr int NumReduceDim = 3; + + // when using lengths[] to create a tensor, lengths[0] is the length of highest dimension + // eg. N of NHWC, so lengths[3] is the dimension C length of NHWC + const std::vector scaleBiasMeanVarLengths = {inOutLengths[3]}; + + // input data of the batchnorm forward algorithm + Tensor x(inOutLengths); + Tensor bnScale(scaleBiasMeanVarLengths); + Tensor bnBias(scaleBiasMeanVarLengths); + + // output data of the batchnorm forward algorithm + Tensor y_ref(inOutLengths); + Tensor y(inOutLengths); + + Tensor resultSaveMean_ref(scaleBiasMeanVarLengths); + Tensor resultSaveInvVariance_ref(scaleBiasMeanVarLengths); + + Tensor resultRunningMean_ref(scaleBiasMeanVarLengths); + Tensor resultRunningVariance_ref(scaleBiasMeanVarLengths); + + auto inOutStrides = x.mDesc.GetStrides(); + auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides(); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + if(updateMovingAverage) + { + if constexpr(std::is_same::value) + { + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + + const float x_mean = 0.0f; + const float x_stddev = 2.5f; + const float noise_stddev = 0.04f; + + resultRunningMean_ref.GenerateTensorValue( + GeneratorTensor_4{x_mean, noise_stddev}, num_thread); + + resultRunningVariance_ref.GenerateTensorValue( + GeneratorTensor_4{x_stddev * x_stddev, noise_stddev}, num_thread); + } + else + { + const float x_mean = 0.0f; + const float x_stddev = 1.0f; + const float noise_stddev = 0.04f; + + // input data in normal distribution + x.GenerateTensorValue(GeneratorTensor_4{x_mean, x_stddev}, num_thread); + + // initialize the runningMean to be values with tiny variation to the mean of the x + // values + resultRunningMean_ref.GenerateTensorValue( + GeneratorTensor_4{x_mean, noise_stddev}, num_thread); + + // initialize the runningVariance to be values with tiny variation to the variance of + // the x values + resultRunningVariance_ref.GenerateTensorValue( + GeneratorTensor_4{x_stddev * x_stddev, noise_stddev}, num_thread); + }; + } + else + { + if constexpr(std::is_same::value) + x.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + else + x.GenerateTensorValue(GeneratorTensor_3{-5.0f, 5.0f}, num_thread); + }; + + if(do_verification) + { + switch(init_method) + { + case 0: + bnScale.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + break; + case 1: + bnScale.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_1{0}, num_thread); + break; + case 2: + bnScale.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + bnScale.GenerateTensorValue(GeneratorTensor_3{-5.0f, 5.0f}, num_thread); + bnBias.GenerateTensorValue(GeneratorTensor_3{-5.0f, 5.0f}, num_thread); + } + }; + + // these buffers are usually provided by the user application + DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem y_dev(sizeof(InOutDataType) * y.mDesc.GetElementSpaceSize()); + DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize()); + DeviceMem bnBias_dev(sizeof(AccDataType) * bnBias.mDesc.GetElementSpaceSize()); + + // mean_dev or resultSaveMean_dev + DeviceMem resultSaveMean_dev(sizeof(AccDataType) * + resultSaveMean_ref.mDesc.GetElementSpaceSize()); + // meansquare_dev or resultSaveInvVariance_dev + DeviceMem resultSaveInvVariance_dev(sizeof(AccDataType) * + resultSaveInvVariance_ref.mDesc.GetElementSpaceSize()); + // resultRunningMean_dev + DeviceMem resultRunningMean_dev(sizeof(AccDataType) * + resultRunningMean_ref.mDesc.GetElementSpaceSize()); + // resultRunningVariance_dev + DeviceMem resultRunningVariance_dev(sizeof(AccDataType) * + resultRunningVariance_ref.mDesc.GetElementSpaceSize()); + + x_dev.ToDevice(x.mData.data()); + bnScale_dev.ToDevice(bnScale.mData.data()); + bnBias_dev.ToDevice(bnBias.mData.data()); + + if(updateMovingAverage) + { + resultRunningMean_dev.ToDevice(resultRunningMean_ref.mData.data()); + resultRunningVariance_dev.ToDevice(resultRunningVariance_ref.mData.data()); + }; + + std::array i_inOutLengths; + std::array i_inOutStrides; + std::array i_scaleBiasMeanVarLengths; + std::array i_scaleBiasMeanVarStrides; + + ck::ranges::copy(inOutLengths, i_inOutLengths.begin()); + ck::ranges::copy(inOutStrides, i_inOutStrides.begin()); + ck::ranges::copy(scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths.begin()); + ck::ranges::copy(scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides.begin()); + + using PassThroughOp = ck::tensor_operation::element_wise::PassThrough; + + using DeviceBatchNormFwdInstance = + ck::tensor_operation::device::DeviceBatchNormFwdImpl; + + auto batchnorm_fwd = DeviceBatchNormFwdInstance{}; + + auto argument_ptr = batchnorm_fwd.MakeArgumentPointer( + i_inOutLengths, + i_inOutStrides, + i_inOutStrides, + {0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[] + i_scaleBiasMeanVarLengths, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + x_dev.GetDeviceBuffer(), + bnScale_dev.GetDeviceBuffer(), + bnBias_dev.GetDeviceBuffer(), + epsilon, + PassThroughOp{}, + y_dev.GetDeviceBuffer(), + saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr, + saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr, + averageFactor, + updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr, + updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr); + + if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get())) + { + std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, " + "exiting!" + << std::endl; + return (false); + }; + + size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get()); + + DeviceMem workspace_dev(workspace_sz); + + batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer(); + + if(time_kernel) + { + float avg_time = 0.0f; + size_t num_bytes = 0; + + size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3]; + size_t invariant_length = inOutLengths[3]; + + avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + // inputing of x, scale, bias, outputing of y + num_bytes += + total_length * sizeof(InOutDataType) * 2 + invariant_length * sizeof(AccDataType) * 2; + + // outputing of mean, inv-variance + num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0; + + // updating of moving mean, variance + num_bytes += updateMovingAverage ? invariant_length * sizeof(AccDataType) * 4 : 0; + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + } + else + (void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + + if(do_verification) + { + + using ReferenceBatchNormFwdInstance = + ck::tensor_operation::host::ReferenceBatchNormFwd; + + auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{}; + + auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer( + i_inOutLengths, + i_inOutStrides, + i_inOutStrides, + {0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[] + i_scaleBiasMeanVarLengths, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + i_scaleBiasMeanVarStrides, + x.mData.data(), + bnScale.mData.data(), + bnBias.mData.data(), + epsilon, + PassThroughOp{}, + y_ref.mData.data(), + saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr, + saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr, + averageFactor, + updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr, + updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr); + + if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get())) + { + std::cout << "The runtime parameters seems not supported by the BatchNorm reference " + "instance, exiting!" + << std::endl; + return (false); + }; + + auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer(); + + (void)invoker_ptr_ref->Run(argument_ptr_ref.get()); + + y_dev.FromDevice(y.mData.data()); + pass = pass && ck::utils::check_err(y, y_ref, "Incorrect normalized output values"); + + if(updateMovingAverage) + { + Tensor resultRunningMean(scaleBiasMeanVarLengths); + Tensor resultRunningVariance(scaleBiasMeanVarLengths); + + resultRunningMean_dev.FromDevice(resultRunningMean.mData.data()); + resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data()); + + pass = pass && ck::utils::check_err(resultRunningMean, + resultRunningMean_ref, + "Incorrect running mean values"); + pass = pass && ck::utils::check_err(resultRunningVariance, + resultRunningVariance_ref, + "Incorrect running variance values"); + }; + + if(saveMeanAndInvVariance) + { + using ck::host_common::dumpBufferToFile; + + Tensor resultSaveMean(scaleBiasMeanVarLengths); + Tensor resultSaveInvVariance(scaleBiasMeanVarLengths); + + resultSaveMean_dev.FromDevice(resultSaveMean.mData.data()); + resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data()); + + pass = pass && ck::utils::check_err( + resultSaveMean, resultSaveMean_ref, "Incorrect saved mean values"); + pass = pass && ck::utils::check_err(resultSaveInvVariance, + resultSaveInvVariance_ref, + "Incorrect saved invvariance values"); + }; + }; + + return (pass); +}; + +const double epsilon = std::numeric_limits::epsilon(); +static const double averageFactor = 0.1; + +int main(int argc, char* argv[]) +{ + bool pass = true; + + if(argc > 1) + { + BatchNormFwdArg arg; + + if(arg.processArgs(argc, argv) < 0) + return (-1); + + if(arg.data_type == 0) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + else if(arg.data_type == 1) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + else if(arg.data_type == 3) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + else if(arg.data_type == 5) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + else if(arg.data_type == 6) + { + if(arg.use_multiblock_welford) + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + else + pass = bnorm_fwd_nhwc_test(arg.do_verification, + arg.init_method, + arg.time_kernel, + arg.inOutLengths, + arg.updateMovingAverage, + arg.saveMeanAndInvVariance, + averageFactor, + epsilon); + } + } + else + { + pass = bnorm_fwd_nhwc_test(true, + 2, + false, // don't time kernel + {128, 16, 6, 512}, + true, + true, + averageFactor, + epsilon); + + pass = pass && bnorm_fwd_nhwc_test(true, + 2, + false, // don't time kernel + {128, 16, 3, 1024}, + true, + true, + averageFactor, + epsilon); + }; + + return (pass ? 0 : 1); +} diff --git a/example/34_batchnorm/batchnorm_infer_impl.hpp b/example/34_batchnorm/batchnorm_infer_impl.hpp index 15170586b636825325c796cdc20eb70996a24bb7..d0b545b2a31d2fc095eece12ce273db823a5d2ac 100644 --- a/example/34_batchnorm/batchnorm_infer_impl.hpp +++ b/example/34_batchnorm/batchnorm_infer_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index 794583954674f403a4902a6fc1cb244b803479d5..eff6b6f3fa4f648f04f5bd7ce2da8d18814c0081 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -1,17 +1,26 @@ -add_custom_target(example_splitK_gemm_xdl) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_splitK_gemm_xdl) -add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) -add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) -add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp) -add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) + add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) -add_dependencies(example_splitK_gemm_xdl - example_splitK_gemm_xdl_fp32 - example_splitK_gemm_xdl_fp16 - example_splitK_gemm_xdl_bfp16 - example_splitK_gemm_xdl_int8) + add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) -if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) -endif() + add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) + + add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) + + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) + endif() + + set(target 1) + endif() +endforeach() diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp similarity index 91% rename from example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp rename to example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp index 7191ecf50ab5252b37855072b4e03f1194df52f6..fdf49a31b719c2aa0e7d6441e52c4dd1475d4f95 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -33,6 +33,7 @@ using ADataType = BF16; using BDataType = BF16; using AccDataType = F32; using CDataType = F32; +using ComputeType = BF16; using ALayout = Row; using BLayout = Col; @@ -46,11 +47,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle // clang-format off -//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| -//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4>; +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>; // clang-format on #include "run_splitK_gemm_example.inc" diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp index efdb315b4e5576aef8d992f6d57c231ac1f83a81..74fb16e15b019e145cdf13549254be1b09fbd8ca 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp index bc2e3d1d52b668a2ea15c9f4ea1db44c4faeb19e..7506f694204b3c3aaa564704eda4f358baf17861 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp index 4eb27824628d8f193ca6108a589794041fc2b2fb..7ebf9144082de53d64f3662e579dd6d246266f1d 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp index eefdbca6b1ae3e49cc221b3037e46888b4a9a9bd..6b0c1aa02d05d848d8280510fa7d05de8b7d4287 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -30,6 +30,7 @@ using ADataType = int8_t; using BDataType = int8_t; using AccDataType = int32_t; using CDataType = int32_t; +using ComputeType = int8_t; using ALayout = Row; using BLayout = Col; @@ -43,11 +44,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle // clang-format off -//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| -//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4>; +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>; // clang-format on #include "run_splitK_gemm_example.inc" diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index f0a0cdf6f13df7081d15ab0099122542ccf6118e..d2337dcda5d30eb57e0e5f764f047cf21314a190 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp index 071e8a7431c057a32ea0061655d81a202774b923..36dcf58d7044b3bf54d3e318e5d095a40d70255d 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp +++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. /* Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1[m, o] @@ -173,6 +173,8 @@ using DeviceGemmInstance = 8, 8, true, + 9, // D0sTransferSrcVectorDim + 4, // D0sTransferSrcScalaerPerVector S<8, 32, 1>, // B1BlockTransfer S<0, 2, 1>, S<0, 2, 1>, @@ -189,7 +191,7 @@ int main(int argc, char* argv[]) { bool do_verification = true; int init_method = 1; - bool time_kernel = false; + bool time_kernel = true; // GEMM shape ck::index_t M = 1024; diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index 9cf960c501ccb279febc8205505ca0edfdeefac4..1ae179e950db0fea8e9f008d449cb8bfe23b54fd 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -1,7 +1,27 @@ -add_custom_target(example_grouped_conv_bwd_data) +list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_data) -add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp) -add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp) + add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) -add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16) -add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16) + add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) + + set(target 1) + endif() +endforeach() + +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) + add_custom_target(example_grouped_conv_bwd_data) + + add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp) + add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) + + set(target 1) + endif() +endforeach() diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp index d07ee7bdc1c4b999f61a84727bb10fe7ea6f4c15..ebb1c606cb6c61de2ce41c846c06b1f1f200f1c4 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -10,7 +10,6 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp similarity index 95% rename from example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp rename to example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp index 55ea8c3a3109469532d0e5c0566da3eb11e7353d..4c28e25e019d41d2a3cb55e410286ab616dc4722 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" #include "common.hpp" using OutDataType = FP16; diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5baa521501704a2ea603f5b975cd6b392285cfb5 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp" +#include "common.hpp" + +using OutDataType = FP16; +using WeiDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using DsDataType = ck::Tuple<>; +using InDataType = FP16; + +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using DsLayout = ck::Tuple<>; +using InLayout = ck::tensor_layout::convolution::GNHWC; + +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = PassThrough; + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + //| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < 2,OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>; +// clang-format on + +#include "run_grouped_conv_bwd_data_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp similarity index 95% rename from example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp rename to example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp index ddf82ec512c61c72a1a2e2323a6645f39f15ce6d..b1554412b1adca12cd49231a1c59cd6f34462ebf 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp @@ -1,6 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" #include "common.hpp" using OutDataType = FP16; diff --git a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc index 0afd8bd70da849085db6e19583e933c3d1b91968..0f0b120cbcc578b2fa7f650c37c9a28148dd834a 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc +++ b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. bool run_conv_bwd_data_bias_relu(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params, diff --git a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc index e50c98bbe844f6ec438c781b7d5d8e254e57d98f..25678491ce90cdf1ab9455f14e091bd5fcd67431 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc +++ b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. bool run_conv_bwd_data(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params, diff --git a/example/39_permute/CMakeLists.txt b/example/39_permute/CMakeLists.txt index 573ad7239e608e7ba947f9ca864a0c7f4da4689b..8b850c89a935ed194cad918a2bdbcf4b0c77d739 100644 --- a/example/39_permute/CMakeLists.txt +++ b/example/39_permute/CMakeLists.txt @@ -1,9 +1,10 @@ add_custom_target(example_permute) add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) +add_example_dependencies(example_permute example_permute_1xHxW_fp16) + add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp) -add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) +add_example_dependencies(example_permute example_permute_NxHxW_fp16) -add_dependencies(example_permute example_permute_1xHxW_fp16) -add_dependencies(example_permute example_permute_NxHxW_fp16) -add_dependencies(example_permute example_permute_HxWx4_fp16) +add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) +add_example_dependencies(example_permute example_permute_HxWx4_fp16) diff --git a/example/39_permute/common.hpp b/example/39_permute/common.hpp index ab612cea1794c422c737a02e5bbe7a6728904c40..54f3a788097bebc6e293389cecd03d8ae68d8494 100644 --- a/example/39_permute/common.hpp +++ b/example/39_permute/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/39_permute/permute_1xHxW_fp16.cpp b/example/39_permute/permute_1xHxW_fp16.cpp index d7f9b80544a452ca7ea062d7d6135b92a2bf19ea..7336c3b631bcadd7d23a2d6b8e0b3d3a31787516 100644 --- a/example/39_permute/permute_1xHxW_fp16.cpp +++ b/example/39_permute/permute_1xHxW_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/39_permute/permute_HxWx4_fp16.cpp b/example/39_permute/permute_HxWx4_fp16.cpp index 342aa134ec5570c84447b137e5be165c9a5f697c..6c24919ded6a0c65556e305455861ca137203b3e 100644 --- a/example/39_permute/permute_HxWx4_fp16.cpp +++ b/example/39_permute/permute_HxWx4_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/39_permute/permute_NxHxW_fp16.cpp b/example/39_permute/permute_NxHxW_fp16.cpp index b53975eb2c8632583f61afde50a03eba32c17c9b..3551d2a7c8decf31a48ca628d99a53639f4fa1fe 100644 --- a/example/39_permute/permute_NxHxW_fp16.cpp +++ b/example/39_permute/permute_NxHxW_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/39_permute/run_permute_bundle_example.inc b/example/39_permute/run_permute_bundle_example.inc index 70406d63f91b95f4d1cce025051a74a0ee3d114e..2c198729226243fb1fb97db6beefa1fd1f62f625 100644 --- a/example/39_permute/run_permute_bundle_example.inc +++ b/example/39_permute/run_permute_bundle_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/39_permute/run_permute_element_example.inc b/example/39_permute/run_permute_element_example.inc index bc6235303039c7cb33a4057446137a9f18ddb186..35871344567ddf62a4fb3c2bd6a2e77cb94a1998 100644 --- a/example/39_permute/run_permute_element_example.inc +++ b/example/39_permute/run_permute_element_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/40_conv2d_fwd_quantization/CMakeLists.txt b/example/40_conv2d_fwd_quantization/CMakeLists.txt index 0a314cd74c28bb8fc0a873b22db4498444db7064..2d804cafe9e7c5cc8cb2c6ca783fc9ffc8b1ab72 100644 --- a/example/40_conv2d_fwd_quantization/CMakeLists.txt +++ b/example/40_conv2d_fwd_quantization/CMakeLists.txt @@ -1,21 +1,24 @@ -# Conv perlayer quantization -add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) -add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) + add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) + add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) + add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) + set(target 1) + endif() +endforeach() -# Conv perchannel quantization -add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) -add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) - -# Conv + bias + relu perlayer quantization -add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) -add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) - -# Conv + bias + relu perchannel quantization -add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) -add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) - -# Conv + bias + tanh perlayer quantization -add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) - -# Conv + bias + tanh perchannel quantization -add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) + # Conv perlayer quantization + add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) + # Conv perchannel quantization + add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) + # Conv + bias + relu perlayer quantization + add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) + # Conv + bias + relu perchannel quantization + add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) + # Conv + bias + tanh perlayer quantization + add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) + # Conv + bias + tanh perchannel quantization + add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) diff --git a/example/40_conv2d_fwd_quantization/common.hpp b/example/40_conv2d_fwd_quantization/common.hpp index 6ee14d750ef6c9f8308ff7b17b86626fc70ca2b8..266b09145c84c456aedb2182554ced7815457410 100644 --- a/example/40_conv2d_fwd_quantization/common.hpp +++ b/example/40_conv2d_fwd_quantization/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp index 5c445d9c50b6ea3235f045a7e569aaf87fb048ea..4573c68658bf44dc49ccade302feb1547746f0f9 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" using InDataType = int8_t; using WeiDataType = int8_t; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp index 0ff85f008fa6fd7101436ec55c02c1e1fbae79cf..005f6263fd46e8b505fbc6859a5fa9ed334e7229 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" using InDataType = int8_t; using WeiDataType = int8_t; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp index f8f996d17e82948076cbb0fa52ff3c378f7cd7cb..62e5e583de864524d2e937683a8c60cdcf9d57da 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" using InDataType = int8_t; using WeiDataType = int8_t; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp index 3b25fec0c4a1bcbf89aacd2c1c506358a09000be..ef98fe7e4f945bab663b7e5ee895633186919df9 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" using InDataType = int8_t; using WeiDataType = int8_t; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp index a98a1e240bcbf390cfaed0ec9230f10d88ed90ce..e524ddb2b297ae218eaf197bea1f149e01a76efc 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" using InDataType = int8_t; using WeiDataType = int8_t; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp index 262594d58b3d3e591ab38209d86c3b683bd7680e..d29a3143c0ed4196a64ade05af0dc295a358c527 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" using InDataType = int8_t; using WeiDataType = int8_t; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp index 6b22055053d3f45ce6f22e03bff25e5b5b7484fa..06c839e4e226e082e9754f12ce39cececbd36269 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp index 1ac8679743777db2c5d99370b8b22ea8004bc8d6..7a9b42d39f971ab1c36476f4c0ead172cb6d2606 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp index f28abe5ebc92d179fd65367b028e76bd4de0cbe4..3495636297d5d83fb3ccc7d236e6135de9b715fd 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp index f468e8adcde28b7fa0099405b76dca7b9d29ce64..2611337254212b326f6eef540c6a4eca5cd19261 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc index 1587c614da8b940b352dbe702c66cde6f267742b..e5b924ad5114a6e31a1ba2875118085329d5e8ce 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once template (conv_param); diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc index 455e0804d4b44964fd98bce14564ad432f9f74bd..9f3a769dcf0979d407169610f5cdbe15dbccb26e 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -178,10 +178,10 @@ int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_el const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; - using InLayout = ck::tensor_layout::convolution::GNHWC; - using WeiLayout = ck::tensor_layout::convolution::GKYXC; + using InLayout = ck::tensor_layout::convolution::NHWGC; + using WeiLayout = ck::tensor_layout::convolution::KYXGC; using BiasLayout = ck::tensor_layout::convolution::G_K; - using OutLayout = ck::tensor_layout::convolution::GNHWK; + using OutLayout = ck::tensor_layout::convolution::NHWGK; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc index 8e75c27746a7dbcb07001c10087a3afcf66ad7e4..9b08fc690d8298b87ce3ebf89f462f3739734815 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -180,10 +180,10 @@ int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_eleme const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; - using InLayout = ck::tensor_layout::convolution::GNHWC; - using WeiLayout = ck::tensor_layout::convolution::GKYXC; + using InLayout = ck::tensor_layout::convolution::NHWGC; + using WeiLayout = ck::tensor_layout::convolution::KYXGC; using RequantScaleLayout = ck::tensor_layout::convolution::G_K; - using OutLayout = ck::tensor_layout::convolution::GNHWK; + using OutLayout = ck::tensor_layout::convolution::NHWGK; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc index 926c033c58dc2af6df73643ab437ef4b64a4351b..267c737e00fc5046e06551ca3e366d01dc78d230 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -162,9 +162,9 @@ int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; - using InLayout = ck::tensor_layout::convolution::GNHWC; - using WeiLayout = ck::tensor_layout::convolution::GKYXC; - using OutLayout = ck::tensor_layout::convolution::GNHWK; + using InLayout = ck::tensor_layout::convolution::NHWGC; + using WeiLayout = ck::tensor_layout::convolution::KYXGC; + using OutLayout = ck::tensor_layout::convolution::NHWGK; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt index 4eb79371a75aa29c6206e8d96aa07a40ba3a525e..ae251e88d205a04968eab430b52c004a1136d853 100644 --- a/example/41_grouped_conv_conv_fwd/CMakeLists.txt +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -1,9 +1,18 @@ -add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) -add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) -add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) -if(NOT GPU_TARGETS MATCHES "gfx940") - add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) +list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) +list(APPEND gpu_list2 gfx908 gfx90a) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list1 AND target EQUAL 0) + add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) + add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) + add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) + if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) + endif(USE_BITINT_EXTENSION_INT4) + set(target 1) + endif() +endforeach() + +if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") + add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) endif() -if(USE_BITINT_EXTENSION_INT4) -add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) -endif(USE_BITINT_EXTENSION_INT4) diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp index 2aea08c40004fac22ea0e45521ad09fd806751d8..e37d41369599819bdf1b9836c1a0587be2e31a0f 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp index b7f80e76d6ce5a690beef4f48361e0a65402cb47..496e676a402f4d12a325c2afef56ec422cb4d721 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp index 15e460948ef528b697a4e3bad7d40ca332f85c15..35d50721dcea27b2bfb42eea6e39da08709dfc7d 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp index 2cc4c07c0d89fc83b233b50967cd595dc0e53ac7..80f6e9ae05712b1df4b05b1bb75fbff62d9ca70a 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #error Should compile this file with ck::int4_t support diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp index 40ff0f69cc14cd644361f3391eea8ed52380adc0..3ade6c811ae5583bbb8a2a3b277c6f4e7f265a87 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc b/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc index a2c97f4d421f9382aa279b257a49299c4016e2c3..0722d497d8df0ebf0acf0cf798215bebcdee6e3a 100644 --- a/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc +++ b/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/42_groupnorm/CMakeLists.txt b/example/42_groupnorm/CMakeLists.txt index a9990c5d890753bd07cf41befd8e955411ff6cb9..e8c306ac582f56607c8bfa309c4a3f16cd64be50 100644 --- a/example/42_groupnorm/CMakeLists.txt +++ b/example/42_groupnorm/CMakeLists.txt @@ -1,2 +1,3 @@ add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp) +add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp) add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp) diff --git a/example/42_groupnorm/common.hpp b/example/42_groupnorm/common.hpp index e159abf3e941e761c8be0a0dd66c46f6cbe347d2..c8f91eb53b5b943bfbdd264b842cf7762a176b82 100644 --- a/example/42_groupnorm/common.hpp +++ b/example/42_groupnorm/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -12,6 +12,7 @@ #include "ck/ck.hpp" #include "ck/utility/reduction_enums.hpp" #include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/library/utility/fill.hpp" diff --git a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp b/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp index b07a26c4c937b7e26420d824b5d00ea7d85b4f36..0ede570e62c9ef71a37e1e52d66b49a04f6d5ce2 100644 --- a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp +++ b/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp @@ -1,31 +1,38 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" constexpr int Rank = 5; constexpr int NumReduceDim = 3; -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; + +#define SAVE_MEAN_INV_STD struct YElementOp { - template - __host__ __device__ void operator()(T& y, const T& x) const + template + __host__ __device__ void operator()(Y& y, const X& x) const { - static_assert(ck::is_same::value || ck::is_same::value || - ck::is_same::value, + static_assert(ck::is_same::value || ck::is_same::value || + ck::is_same::value, + "Data type is not supported by this operation!"); + + static_assert(ck::is_same::value || ck::is_same::value || + ck::is_same::value, "Data type is not supported by this operation!"); - T a; + X a; ck::tensor_operation::element_wise::Sigmoid{}(a, x); - y = x * a; + y = ck::type_convert(x * a); }; }; @@ -35,6 +42,7 @@ using DeviceInstance = BetaDataType, ComputeDataType, YDataType, + SaveMeanInvStdDataType, YElementOp, Rank, NumReduceDim, @@ -49,7 +57,8 @@ using DeviceInstance = 2, // GammaScalarPerVector 1, // BetaVecDim (0=M, 1=K) 2, // BetaScalarPerVector - 2>; // OutScalarPerVector + 2, // YScalarPerVector + 1>; // SaveMeanInvStdScalarPerVector #include "run_groupnorm_example.inc" diff --git a/example/42_groupnorm/groupnorm_splitk_fp16.cpp b/example/42_groupnorm/groupnorm_splitk_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5f56268e029d6354cbaae653813b11dfc5317318 --- /dev/null +++ b/example/42_groupnorm/groupnorm_splitk_fp16.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +constexpr int Rank = 5; +constexpr int NumReduceDim = 3; + +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using YElementOp = ck::tensor_operation::element_wise::Swish; + +#define SAVE_MEAN_INV_STD + +using DeviceInstance = + ck::tensor_operation::device::DeviceNormalizationSplitKImpl; // SaveMeanInvStdScalarPerVector + +#include "run_groupnorm_example.inc" + +int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); } diff --git a/example/42_groupnorm/groupnorm_swish_fp16.cpp b/example/42_groupnorm/groupnorm_swish_fp16.cpp index c52243bfb0c4724df5fe418dc741362f8a5c92d5..97cd4698aac3bbf1364a958b528649e42b1447c4 100644 --- a/example/42_groupnorm/groupnorm_swish_fp16.cpp +++ b/example/42_groupnorm/groupnorm_swish_fp16.cpp @@ -1,17 +1,20 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" constexpr int Rank = 5; constexpr int NumReduceDim = 3; -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using ComputeDataType = float; -using YElementOp = ck::tensor_operation::element_wise::Swish; +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using SaveMeanInvStdDataType = float; +using ComputeDataType = float; +using YElementOp = ck::tensor_operation::element_wise::Swish; + +#define SAVE_MEAN_INV_STD using DeviceInstance = ck::tensor_operation::device::DeviceNormalizationImpl; // OutScalarPerVector + 2, // YScalarPerVector + 1>; // SaveMeanInvStdScalarPerVector #include "run_groupnorm_example.inc" diff --git a/example/42_groupnorm/run_groupnorm_example.inc b/example/42_groupnorm/run_groupnorm_example.inc index bd7eb98ca0f8bb95b687ea3f54d227d72fb40216..89117a9b9497a5497ef344202843138a1f827310 100644 --- a/example/42_groupnorm/run_groupnorm_example.inc +++ b/example/42_groupnorm/run_groupnorm_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -34,6 +34,8 @@ int run_groupnorm_example(int argc, char* argv[]) Tensor y({N, H, W, G, C}); Tensor gamma({G, C}); Tensor beta({G, C}); + Tensor save_mean({N, G}); + Tensor save_inv_std({N, G}); ck::utils::FillUniformDistribution{0.f, 1.f}(x); ck::utils::FillUniformDistribution{0.f, 1.f}(gamma); @@ -43,6 +45,11 @@ int run_groupnorm_example(int argc, char* argv[]) DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); +#ifdef SAVE_MEAN_INV_STD + DeviceMem save_mean_dev(sizeof(SaveMeanInvStdDataType) * save_mean.mDesc.GetElementSpaceSize()); + DeviceMem save_inv_std_dev(sizeof(SaveMeanInvStdDataType) * + save_inv_std.mDesc.GetElementSpaceSize()); +#endif x_dev.ToDevice(x.mData.data()); gamma_dev.ToDevice(gamma.mData.data()); @@ -57,14 +64,23 @@ int run_groupnorm_example(int argc, char* argv[]) {0, 0, 0, C, 1}, {0, 0, 0, C, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, + std::vector{save_mean.mDesc.GetStrides().begin(), + save_mean.mDesc.GetStrides().end()}, {1, 2, 4}, // reduction dimension: [H, W, C] 1e-6, x_dev.GetDeviceBuffer(), gamma_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(), y_dev.GetDeviceBuffer(), +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.GetDeviceBuffer(), + save_inv_std_dev.GetDeviceBuffer(), +#else nullptr, nullptr, +#endif y_element_op); if(!device_instance.IsSupportedArgument(argument_ptr.get())) @@ -73,6 +89,10 @@ int run_groupnorm_example(int argc, char* argv[]) return 1; }; + size_t workspace_sz = device_instance.GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + auto invoker_ptr = device_instance.MakeInvokerPointer(); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, true}); @@ -88,21 +108,40 @@ int run_groupnorm_example(int argc, char* argv[]) bool pass = true; { Tensor host_y({N, H, W, G, C}); - using ReferenceInstance = ck::tensor_operation::host::ReferenceGroupnorm; + Tensor host_save_mean(HostTensorDescriptor{N, G}); + Tensor host_save_inv_std(HostTensorDescriptor{N, G}); + using ReferenceInstance = + ck::tensor_operation::host::ReferenceGroupnorm; ReferenceInstance ref; - auto ref_argument = - ref.MakeArgument(x, gamma, beta, host_y, y_element_op, {N, H, W, G, C}, 1e-6); - auto ref_invoker = ref.MakeInvoker(); + auto ref_argument = ref.MakeArgument(x, + gamma, + beta, + host_y, + host_save_mean, + host_save_inv_std, + y_element_op, + {N, H, W, G, C}, + 1e-6); + auto ref_invoker = ref.MakeInvoker(); ref_invoker.Run(ref_argument); y_dev.FromDevice(y.mData.data()); pass &= ck::utils::check_err(y, host_y, "Error: Incorrect results", 1e-3, 1e-3); +#ifdef SAVE_MEAN_INV_STD + save_mean_dev.FromDevice(save_mean.mData.data()); + save_inv_std_dev.FromDevice(save_inv_std.mData.data()); + pass &= ck::utils::check_err( + save_mean, host_save_mean, "Error: Incorrect results (mean)", 1e-3, 1e-3); + pass &= ck::utils::check_err( + save_inv_std, host_save_inv_std, "Error: Incorrect results (inv_std)", 1e-3, 1e-3); +#endif } return (pass ? 0 : 1); diff --git a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp index 7ac4b68272e2b0a1a5cbc09fa3734e190d6897da..ebba88cf41fc493103dc78b1252e6d3ad5e19739 100644 --- a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -8,7 +8,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/utility/check_err.hpp" diff --git a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp index 764e55ef558bd977fca8f25f16b2bba4c111114b..4ab26293ccfddeb5117d409d0af9c00a303ce984 100644 --- a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp +++ b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -8,7 +8,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/utility/check_err.hpp" diff --git a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp index 7d6ff12eeafe8de6f45e93567959d86d748c783c..c02d540983ed17f14f3376ab158308147337fc11 100644 --- a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp +++ b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -167,20 +167,31 @@ int main() XElementwiseOperation>(x, a, b, mn, XElementwiseOperation{}); Tensor host_y(f_host_tensor_descriptor2d(M, N, Stride)); + Tensor host_save_mean({M}); + Tensor host_save_inv_std({M}); using ReferenceInstance = ck::tensor_operation::host::ReferenceLayernorm; ReferenceInstance ref; - auto ref_argument = - ref.MakeArgument(x, gamma, beta, host_y, YElementwiseOperation{}, {M, N}, {1}, 1e-4); - auto ref_invoker = ref.MakeInvoker(); + auto ref_argument = ref.MakeArgument(x, + gamma, + beta, + host_y, + host_save_mean, + host_save_inv_std, + YElementwiseOperation{}, + {M, N}, + {1}, + 1e-4); + auto ref_invoker = ref.MakeInvoker(); ref_invoker.Run(ref_argument); y_dev.FromDevice(y.mData.data()); diff --git a/example/46_gemm_add_multiply/common.hpp b/example/46_gemm_add_multiply/common.hpp index 3ba78dfe47ba4cac2f647bf5cb609561e6989275..2c656cf44159a594f0b8ba3b40b7d8d936d57cb2 100644 --- a/example/46_gemm_add_multiply/common.hpp +++ b/example/46_gemm_add_multiply/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/46_gemm_add_multiply/gemm_add_multiply_dl_fp16.cpp b/example/46_gemm_add_multiply/gemm_add_multiply_dl_fp16.cpp index 28c3939fa611d536600bdce9c7609705e0373da1..58a399f226c1b59073dd000e382e673badb5e1fb 100644 --- a/example/46_gemm_add_multiply/gemm_add_multiply_dl_fp16.cpp +++ b/example/46_gemm_add_multiply/gemm_add_multiply_dl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp" diff --git a/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp b/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp index d5aa41f1b6d5a97c5c1b614e2765a86fde7343e1..56417b101d9528674535ab93f6fdc7ab36d47c3d 100644 --- a/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp +++ b/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" diff --git a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt index d1b3dd4be22c07f2b8fc87947328c849ff10650b..14432f6e23d4adbf09e0df6c805a68f1a5571f69 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt +++ b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt @@ -1 +1,8 @@ -add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp) +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp index 30c98e534a99fa05ff6ebec044c5b70df7419254..a90a6340a431a55c271dca3d3d0d1771382218f5 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp +++ b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -121,7 +121,8 @@ using DeviceOpInstance = 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock - MaskingSpec>; // MaskingSpecialization + MaskingSpec, // MaskingSpecialization + 1>; // Ref Gemm0: fp16 in, fp32 out using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp" + +template +std::vector f_tensor_strides_ncdhw(ck::index_t N_, + ck::index_t C_, + ck::index_t D, + ck::index_t H, + ck::index_t W, + TensorLayout layout) +{ + using namespace ck::literals; + (void)N_; + if constexpr(ck::is_same::value) + return {C_ * D * H * W, D * H * W, H * W, W, 1_uz}; + else if constexpr(ck::is_same::value) + return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}; +}; + +template +HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, + std::size_t C_, + std::size_t D, + std::size_t H, + std::size_t W, + TensorLayout layout) +{ + using namespace ck::literals; + + if constexpr(ck::is_same::value) + { + return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); + } + else if constexpr(ck::is_same::value) + { + return HostTensorDescriptor({N_, C_, D, H, W}, + {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + } +}; + +template +bool pool3d_test(bool do_verification, + bool time_kernel, + ck::index_t N, + ck::index_t C, + ck::index_t Z, + ck::index_t Y, + ck::index_t X, + ck::index_t Di, + ck::index_t Hi, + ck::index_t Wi, + ck::index_t window_stride_d, + ck::index_t window_stride_h, + ck::index_t window_stride_w, + ck::index_t window_dilation_d, + ck::index_t window_dilation_h, + ck::index_t window_dilation_w, + ck::index_t in_left_pad_d, + ck::index_t in_left_pad_h, + ck::index_t in_left_pad_w, + ck::index_t in_right_pad_d, + ck::index_t in_right_pad_h, + ck::index_t in_right_pad_w) +{ + const ck::index_t Zs = (Z - 1) * window_dilation_d + 1; + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Zs) / window_stride_d + 1; + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; + + const std::vector window_spatial_lengths{Z, Y, X}; + const std::vector window_strides{ + window_stride_d, window_stride_h, window_stride_w}; + const std::vector window_dilations{ + window_dilation_d, window_dilation_h, window_dilation_w}; + const std::vector input_left_pads{in_left_pad_d, in_left_pad_h, in_left_pad_w}; + const std::vector input_right_pads{in_right_pad_d, in_right_pad_h, in_right_pad_w}; + + Tensor in_n_c_di_hi_wi(f_host_tensor_descriptor(N, C, Di, Hi, Wi, InLayout{})); + Tensor out_n_c_do_ho_wo_host( + f_host_tensor_descriptor(N, C, Do, Ho, Wo, OutLayout{})); + Tensor out_indices_n_c_do_ho_wo_host( + f_host_tensor_descriptor(N, C, Do, Ho, Wo, OutLayout{})); + Tensor out_n_c_do_ho_wo_device( + f_host_tensor_descriptor(N, C, Do, Ho, Wo, OutLayout{})); + Tensor out_indices_n_c_do_ho_wo_device( + f_host_tensor_descriptor(N, C, Do, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_di_hi_wi: " << in_n_c_di_hi_wi.mDesc << std::endl; + std::cout << "out_n_c_do_ho_wo: " << out_n_c_do_ho_wo_host.mDesc << std::endl; + + in_n_c_di_hi_wi.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_di_hi_wi.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_c_do_ho_wo_device.mDesc.GetElementSpaceSize()); + DeviceMem out_indices_device_buf(sizeof(IndexDataType) * + out_indices_n_c_do_ho_wo_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in_n_c_di_hi_wi.mData.data()); + + auto pool = DevicePoolFwdInstance{}; + auto invoker_ptr = pool.MakeInvokerPointer(); + auto argument_ptr = pool.MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + {N, C, Di, Hi, Wi}, + {Z, Y, X}, + {N, C, Do, Ho, Wo}, + f_tensor_strides_ncdhw(N, C, Di, Hi, Wi, InLayout{}), + f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, OutLayout{}), + f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, OutLayout{}), + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {2, 3, 4}); + + if(!pool.IsSupportedArgument(argument_ptr.get())) + { + throw std::runtime_error("wrong! device_op with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + std::cout << "Perf: " << ave_time << std::endl; + + bool pass = true; + + if(do_verification) + { + using ReferencePoolingFwdInstance = + ck::tensor_operation::host::ReferencePoolingFwd<5, + 3, + InDataType, + OutDataType, + ComputeDataType, + IndexDataType, + ReduceOpId, + PropagateNan, + OutputIndex>; + + auto ref_pooling = ReferencePoolingFwdInstance{}; + auto ref_pooling_invoker = ref_pooling.MakeInvoker(); + auto ref_pooling_argument = ref_pooling.MakeArgument(in_n_c_di_hi_wi, + out_n_c_do_ho_wo_host, + out_indices_n_c_do_ho_wo_host, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + + ref_pooling_invoker.Run(ref_pooling_argument); + + out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data()); + + pass = pass && ck::utils::check_err(out_n_c_do_ho_wo_device, out_n_c_do_ho_wo_host); + + if constexpr(OutputIndex) + { + out_indices_device_buf.FromDevice(out_indices_n_c_do_ho_wo_device.mData.data()); + + pass = pass && ck::utils::check_err(out_indices_n_c_do_ho_wo_device, + out_indices_n_c_do_ho_wo_host); + }; + } + + return (pass); +}; diff --git a/example/48_pool3d_fwd/pool3d_fwd_fp16.cpp b/example/48_pool3d_fwd/pool3d_fwd_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b9ac61033d5e4aa26a6f4fc20dae04d04dafa106 --- /dev/null +++ b/example/48_pool3d_fwd/pool3d_fwd_fp16.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "pool3d_fwd_common.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using ComputeDataType = float; + +using IndexDataType = int32_t; + +using InLayout = ck::tensor_layout::convolution::NDHWC; +using OutLayout = ck::tensor_layout::convolution::NDHWC; + +#if 1 +static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +#else +static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +#endif + +static constexpr bool OutputIndex = false; +static constexpr bool PropagateNan = false; + +using DevicePoolFwdInstance = + ck::tensor_operation::device::DevicePool3dFwd_NDHWC_NDHWC; // InSrcOutDstVectorSize + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // Pool shape + ck::index_t N = 2; + ck::index_t C = 32; + ck::index_t Z = 2; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Di = 30; + ck::index_t Hi = 30; + ck::index_t Wi = 30; + ck::index_t window_stride_d = 2; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_d = 1; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_d = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_d = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + bool pass = pool3d_test(do_verification, + time_kernel, + N, + C, + Z, + Y, + X, + Di, + Hi, + Wi, + window_stride_d, + window_stride_h, + window_stride_w, + window_dilation_d, + window_dilation_h, + window_dilation_w, + in_left_pad_d, + in_left_pad_h, + in_left_pad_w, + in_right_pad_d, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/example/49_maxpool2d_bwd/CMakeLists.txt b/example/49_maxpool2d_bwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b29cf9ccbce89d48bcff009d46b94e8c8e036741 --- /dev/null +++ b/example/49_maxpool2d_bwd/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp) +add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp) +add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp) diff --git a/example/49_maxpool2d_bwd/maxpool2d_bwd_bf16.cpp b/example/49_maxpool2d_bwd/maxpool2d_bwd_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a6c8f3b884ef443acf419c08cffb1636591afbc4 --- /dev/null +++ b/example/49_maxpool2d_bwd/maxpool2d_bwd_bf16.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "maxpool2d_bwd_common.hpp" + +using InDataType = ck::bhalf_t; +using OutDataType = ck::bhalf_t; +using IndexDataType = int32_t; +using ComputeDataType = float; +using DInDataType = ck::bhalf_t; +using DOutDataType = ck::bhalf_t; + +static constexpr bool PropagateNan = false; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // Pool shape + ck::index_t N = 1; + ck::index_t C = 1; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 32; + ck::index_t Wi = 32; + ck::index_t window_stride_h = 1; + ck::index_t window_stride_w = 1; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 0; + ck::index_t in_left_pad_w = 0; + ck::index_t in_right_pad_h = 0; + ck::index_t in_right_pad_w = 0; + + bool pass = maxpool_bwd_test(do_verification, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + window_dilation_h, + window_dilation_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp b/example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2c1e6693755e3581623274c323ec34a18f99c68e --- /dev/null +++ b/example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp" + +template +bool maxpool_bwd_test(bool do_verification, + bool time_kernel, + ck::index_t N, + ck::index_t C, + ck::index_t Y, + ck::index_t X, + ck::index_t Hi, + ck::index_t Wi, + ck::index_t window_stride_h, + ck::index_t window_stride_w, + ck::index_t window_dilation_h, + ck::index_t window_dilation_w, + ck::index_t in_left_pad_h, + ck::index_t in_left_pad_w, + ck::index_t in_right_pad_h, + ck::index_t in_right_pad_w) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using DevicePoolFwdInstance = + ck::tensor_operation::device::DevicePool2dFwd_NHWC_NHWC; // InSrcOutDstVectorSize + + using DeviceMaxPoolBwdInstance = ck::tensor_operation::device:: + DeviceMaxPoolBwdImpl; + + const ck::index_t Ys = (Y - 1) * window_dilation_h + 1; + const ck::index_t Xs = (X - 1) * window_dilation_w + 1; + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1; + + const std::vector window_spatial_lengths{Y, X}; + const std::vector window_strides{window_stride_h, window_stride_w}; + const std::vector window_dilations{window_dilation_h, window_dilation_w}; + const std::vector input_left_pads{in_left_pad_h, in_left_pad_w}; + const std::vector input_right_pads{in_right_pad_h, in_right_pad_w}; + + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { + using namespace ck::literals; + // reference need Tensor with NCHW order + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + }; + + // in + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi)); + + // out + Tensor out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo)); + Tensor out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo)); + + // indices + Tensor indices_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo)); + Tensor indices_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo)); + + // dout + Tensor dout_n_c_ho_wo(f_host_tensor_descriptor(N, C, Ho, Wo)); + + // din + Tensor din_n_c_hi_wi_host(f_host_tensor_descriptor(N, C, Hi, Wi)); + Tensor din_n_c_hi_wi_device(f_host_tensor_descriptor(N, C, Hi, Wi)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "out_n_c_ho_wo: " << out_n_c_ho_wo_host.mDesc << std::endl; + std::cout << "indices_n_c_ho_wo: " << indices_n_c_ho_wo_host.mDesc << std::endl; + std::cout << "dout_n_c_ho_wo: " << dout_n_c_ho_wo.mDesc << std::endl; + std::cout << "din_n_c_hi_wi: " << din_n_c_hi_wi_host.mDesc << std::endl; + + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + dout_n_c_ho_wo.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_c_ho_wo_device.mDesc.GetElementSpaceSize()); + DeviceMem indices_device_buf(sizeof(IndexDataType) * + indices_n_c_ho_wo_device.mDesc.GetElementSpaceSize()); + DeviceMem dout_device_buf(sizeof(DOutDataType) * dout_n_c_ho_wo.mDesc.GetElementSpaceSize()); + DeviceMem din_device_buf(sizeof(DInDataType) * + din_n_c_hi_wi_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + dout_device_buf.ToDevice(dout_n_c_ho_wo.mData.data()); + + auto pool_fwd = DevicePoolFwdInstance{}; + auto pool_fwd_invoker_ptr = pool_fwd.MakeInvokerPointer(); + auto pool_fwd_argument_ptr = pool_fwd.MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + {N, C, Hi, Wi}, + window_spatial_lengths, + {N, C, Ho, Wo}, + {C * Hi * Wi, 1, Wi * C, C}, + {C * Ho * Wo, 1, Wo * C, C}, + {C * Ho * Wo, 1, Wo * C, C}, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {2, 3}); + + if(!pool_fwd.IsSupportedArgument(pool_fwd_argument_ptr.get())) + { + throw std::runtime_error("wrong! pool_fwd with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time_fwd = + pool_fwd_invoker_ptr->Run(pool_fwd_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + auto pool_bwd = DeviceMaxPoolBwdInstance{}; + auto pool_bwd_invoker_ptr = pool_bwd.MakeInvokerPointer(); + auto pool_bwd_argument_ptr = pool_bwd.MakeArgumentPointer( + static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + dout_n_c_ho_wo.mDesc.GetElementSpaceSize(), + din_n_c_hi_wi_device.mDesc.GetElementSpaceSize(), + window_spatial_lengths, + window_strides, + window_dilations); + + if(!pool_bwd.IsSupportedArgument(pool_bwd_argument_ptr.get())) + { + throw std::runtime_error("wrong! pool_bwd with the specified compilation parameters does " + "not support this problem"); + } + + size_t pool_bwd_workspace_sz = pool_bwd.GetWorkSpaceSize(pool_bwd_argument_ptr.get()); + DeviceMem pool_bwd_workspace_device_buf(pool_bwd_workspace_sz); + pool_bwd.SetWorkSpacePointer(pool_bwd_argument_ptr.get(), + pool_bwd_workspace_device_buf.GetDeviceBuffer()); + + float ave_time_bwd = + pool_bwd_invoker_ptr->Run(pool_bwd_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Pool fwd perf: " << ave_time_fwd << " ms" << std::endl; + std::cout << "Pool bwd perf: " << ave_time_bwd << " ms" << std::endl; + + bool pass = true; + + if(do_verification) + { + using ReferencePoolingFwdInstance = + ck::tensor_operation::host::ReferencePoolingFwd<4, + 2, + InDataType, + OutDataType, + ComputeDataType, + IndexDataType, + ck::ReduceTensorOp::MAX, + PropagateNan, + true>; + + auto ref_pooling_fwd = ReferencePoolingFwdInstance{}; + auto ref_pooling_fwd_invoker = ref_pooling_fwd.MakeInvoker(); + auto ref_pooling_fwd_argument = ref_pooling_fwd.MakeArgument(in_n_c_hi_wi, + out_n_c_ho_wo_host, + indices_n_c_ho_wo_host, + window_spatial_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + ref_pooling_fwd_invoker.Run(ref_pooling_fwd_argument); + + using ReferencePoolingBwdInstance = + ck::tensor_operation::host::ReferenceMaxPoolBwd; + + auto ref_pooling_bwd = ReferencePoolingBwdInstance{}; + auto ref_pooling_bwd_invoker = ref_pooling_bwd.MakeInvoker(); + auto ref_pooling_bwd_argument = ref_pooling_bwd.MakeArgument( + dout_n_c_ho_wo, indices_n_c_ho_wo_host, din_n_c_hi_wi_host, PassThrough{}); + + ref_pooling_bwd_invoker.Run(ref_pooling_bwd_argument); + + out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); + indices_device_buf.FromDevice(indices_n_c_ho_wo_device.mData.data()); + din_device_buf.FromDevice(din_n_c_hi_wi_device.mData.data()); + + pass = pass && ck::utils::check_err(out_n_c_ho_wo_device, out_n_c_ho_wo_host); + pass = pass && ck::utils::check_err(indices_n_c_ho_wo_device, indices_n_c_ho_wo_host); + pass = pass && ck::utils::check_err(din_n_c_hi_wi_device, din_n_c_hi_wi_host); + } + + return (pass); +}; diff --git a/example/49_maxpool2d_bwd/maxpool2d_bwd_fp16.cpp b/example/49_maxpool2d_bwd/maxpool2d_bwd_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a4f982d855bbddd7a95d683d320e00bb12da0c3d --- /dev/null +++ b/example/49_maxpool2d_bwd/maxpool2d_bwd_fp16.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "maxpool2d_bwd_common.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using IndexDataType = int32_t; +using ComputeDataType = float; +using DInDataType = ck::half_t; +using DOutDataType = ck::half_t; + +static constexpr bool PropagateNan = false; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // Pool shape + ck::index_t N = 1; + ck::index_t C = 1; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 32; + ck::index_t Wi = 32; + ck::index_t window_stride_h = 1; + ck::index_t window_stride_w = 1; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 0; + ck::index_t in_left_pad_w = 0; + ck::index_t in_right_pad_h = 0; + ck::index_t in_right_pad_w = 0; + + bool pass = maxpool_bwd_test(do_verification, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + window_dilation_h, + window_dilation_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp b/example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c14928a9367fc69f44260a88ca1200944f76e44e --- /dev/null +++ b/example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/reduction_enums.hpp" + +#include "maxpool2d_bwd_common.hpp" + +using InDataType = float; +using OutDataType = float; +using IndexDataType = int32_t; +using ComputeDataType = float; +using DInDataType = float; +using DOutDataType = float; + +static constexpr bool PropagateNan = false; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // Pool shape + ck::index_t N = 1; + ck::index_t C = 1; + ck::index_t Y = 2; + ck::index_t X = 2; + ck::index_t Hi = 32; + ck::index_t Wi = 32; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t window_dilation_h = 1; + ck::index_t window_dilation_w = 1; + ck::index_t in_left_pad_h = 0; + ck::index_t in_left_pad_w = 0; + ck::index_t in_right_pad_h = 0; + ck::index_t in_right_pad_w = 0; + + bool pass = maxpool_bwd_test(do_verification, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + window_dilation_h, + window_dilation_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/example/50_put_element/CMakeLists.txt b/example/50_put_element/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1b0020ebcf65a87ad39b004c149385cdd7de89ae --- /dev/null +++ b/example/50_put_element/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_put_element_fp16 put_element_fp16.cpp) diff --git a/example/50_put_element/put_element_fp16.cpp b/example/50_put_element/put_element_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..791747d8756962cb2e24c606d4d9f50ba3ee3410 --- /dev/null +++ b/example/50_put_element/put_element_fp16.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using XDataType = ck::half_t; +using YDataType = ck::half_t; +using IndexDataType = int32_t; + +using YElementwiseOp = ck::tensor_operation::element_wise::PassThrough; + +using DeviceInstance = + ck::tensor_operation::device::DevicePutElementImpl; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + int N = 1024; + + Tensor x(HostTensorDescriptor{N}); + Tensor indices(HostTensorDescriptor{N}); + Tensor y(HostTensorDescriptor{N}); + + x.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + for(int i = 0; i < N; ++i) + indices(i) = i; + + DeviceMem x_device_buf(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); + DeviceMem y_device_buf(sizeof(YDataType) * y.mDesc.GetElementSpaceSize()); + DeviceMem indices_device_buf(sizeof(IndexDataType) * indices.mDesc.GetElementSpaceSize()); + + x_device_buf.ToDevice(x.mData.data()); + indices_device_buf.ToDevice(indices.mData.data()); + + auto put_instance = DeviceInstance{}; + auto put_invoker_ptr = put_instance.MakeInvokerPointer(); + auto put_argument_ptr = put_instance.MakeArgumentPointer( + static_cast(x_device_buf.GetDeviceBuffer()), + static_cast(indices_device_buf.GetDeviceBuffer()), + static_cast(y_device_buf.GetDeviceBuffer()), + N, + N, + YElementwiseOp{}); + + if(!put_instance.IsSupportedArgument(put_argument_ptr.get())) + { + throw std::runtime_error("argument is not supported!"); + } + + float ave_time = + put_invoker_ptr->Run(put_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + Tensor y_host(HostTensorDescriptor{N}); + + for(int i = 0; i < N; ++i) + { + IndexDataType idx = indices(i); + y_host(idx) = x(i); + } + + y_device_buf.FromDevice(y.mData.data()); + pass = ck::utils::check_err(y, y_host); + } + + return (pass ? 0 : 1); +} diff --git a/example/51_avgpool3d_bwd/CMakeLists.txt b/example/51_avgpool3d_bwd/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..fef0c66835a804e943346f099481c98336a96b4d --- /dev/null +++ b/example/51_avgpool3d_bwd/CMakeLists.txt @@ -0,0 +1,3 @@ +add_example_executable(example_avgpool3d_bwd_bf16 avgpool3d_bwd_bf16.cpp) +add_example_executable(example_avgpool3d_bwd_fp16 avgpool3d_bwd_fp16.cpp) +add_example_executable(example_avgpool3d_bwd_fp32 avgpool3d_bwd_fp32.cpp) diff --git a/example/51_avgpool3d_bwd/avgpool3d_bwd_bf16.cpp b/example/51_avgpool3d_bwd/avgpool3d_bwd_bf16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a5ab6ef24a4da1cb23719719a4c6d9cb8187ded3 --- /dev/null +++ b/example/51_avgpool3d_bwd/avgpool3d_bwd_bf16.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp" + +#include "avgpool3d_bwd_common.hpp" + +using DOutDataType = ck::bhalf_t; +using DInDataType = ck::bhalf_t; +using ComputeDataType = float; + +#if 1 +using DOutLayout = ck::tensor_layout::convolution::NDHWC; +using DInLayout = ck::tensor_layout::convolution::NDHWC; +#else +using DOutLayout = ck::tensor_layout::convolution::NCDHW; +using DInLayout = ck::tensor_layout::convolution::NCDHW; +#endif + +using DevicePoolBwdInstance = + ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC; // InSrcOutDstVectorSize + +int main() +{ + std::vector window_lengths = {5, 5, 5}; + std::vector window_strides = {2, 2, 2}; + std::vector window_dilations = {2, 2, 2}; + std::vector dinput_left_pads = {0, 0, 0}; + std::vector dinput_right_pads = {0, 0, 0}; + + ck::index_t N = 1; + ck::index_t C = 16; + ck::index_t Di = 40; + ck::index_t Hi = 40; + ck::index_t Wi = 40; + + pool3d_bwd_test( + true, + false, + N, + C, + Di, + Hi, + Wi, + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); +} diff --git a/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp b/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..394f046b1e9d2136e42001f366d7423f1f5e8600 --- /dev/null +++ b/example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp" + +template +std::vector f_tensor_strides_ncdhw(ck::index_t N_, + ck::index_t C_, + ck::index_t D, + ck::index_t H, + ck::index_t W, + TensorLayout layout) +{ + using namespace ck::literals; + (void)N_; + if constexpr(ck::is_same::value) + return {C_ * D * H * W, D * H * W, H * W, W, 1_uz}; + else if constexpr(ck::is_same::value) + return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}; +}; + +template +HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, + std::size_t C_, + std::size_t D, + std::size_t H, + std::size_t W, + TensorLayout layout) +{ + using namespace ck::literals; + + if constexpr(ck::is_same::value) + { + return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); + } + else if constexpr(ck::is_same::value) + { + return HostTensorDescriptor({N_, C_, D, H, W}, + {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + } +}; + +template +bool pool3d_bwd_test(bool do_verification, + bool time_kernel, + ck::index_t N, + ck::index_t C, + ck::index_t Di, + ck::index_t Hi, + ck::index_t Wi, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector dinput_left_pads, + std::vector dinput_right_pads) +{ + auto OutSpatialLength = [&](auto InSpatialLength, int index) { + ck::index_t left_pad = dinput_left_pads[index]; + ck::index_t right_pad = dinput_right_pads[index]; + ck::index_t window_len = window_lengths[index]; + ck::index_t stride = window_strides[index]; + ck::index_t dilation = window_dilations[index]; + ck::index_t eff = (window_len - 1) * dilation + 1; + return (InSpatialLength + left_pad + right_pad - eff) / stride + 1; + }; + + ck::index_t Do = OutSpatialLength(Di, 0); + ck::index_t Ho = OutSpatialLength(Hi, 1); + ck::index_t Wo = OutSpatialLength(Wi, 2); + + Tensor dout(f_host_tensor_descriptor(N, C, Do, Ho, Wo, DOutLayout{})); + Tensor din_dev(f_host_tensor_descriptor(N, C, Di, Hi, Wi, DInLayout{})); + Tensor din_host(f_host_tensor_descriptor(N, C, Di, Hi, Wi, DInLayout{})); + + std::cout << "dout: " << dout.mDesc << std::endl; + std::cout << "din_host: " << din_host.mDesc << std::endl; + + dout.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem dout_device_buf(sizeof(DOutDataType) * dout.mDesc.GetElementSpaceSize()); + DeviceMem din_device_buf(sizeof(DInDataType) * din_dev.mDesc.GetElementSpaceSize()); + + dout_device_buf.ToDevice(dout.mData.data()); + din_device_buf.SetZero(); + + auto pool = DevicePoolBwdInstance{}; + auto invoker_ptr = pool.MakeInvokerPointer(); + auto argument_ptr = + pool.MakeArgumentPointer(static_cast(dout_device_buf.GetDeviceBuffer()), + static_cast(din_device_buf.GetDeviceBuffer()), + {N, C, Do, Ho, Wo}, + {N, C, Di, Hi, Wi}, + f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, DOutLayout{}), + f_tensor_strides_ncdhw(N, C, Di, Hi, Wi, DInLayout{}), + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); + + if(!pool.IsSupportedArgument(argument_ptr.get())) + { + throw std::runtime_error("wrong! device_op with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + std::cout << "Perf: " << ave_time << std::endl; + + bool pass = true; + + if(do_verification) + { + auto ref_pool = + ck::tensor_operation::host::ReferenceAvgPoolBwd<3, DInDataType, DOutDataType>(); + + auto ref_invoker = ref_pool.MakeInvoker(); + + auto ref_argument = ref_pool.MakeArgument(din_host, + dout, + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); + + ref_invoker.Run(ref_argument); + + din_device_buf.FromDevice(din_dev.mData.data()); + pass = ck::utils::check_err(din_dev, din_host); + } + + return pass; +} diff --git a/example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp b/example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..578f563d5c29e8818f9319e7fb79920a785f58c5 --- /dev/null +++ b/example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp" + +#include "avgpool3d_bwd_common.hpp" + +using DOutDataType = ck::half_t; +using DInDataType = ck::half_t; +using ComputeDataType = float; + +#if 1 +using DOutLayout = ck::tensor_layout::convolution::NDHWC; +using DInLayout = ck::tensor_layout::convolution::NDHWC; +#else +using DOutLayout = ck::tensor_layout::convolution::NCDHW; +using DInLayout = ck::tensor_layout::convolution::NCDHW; +#endif + +using DevicePoolBwdInstance = + ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC; // InSrcOutDstVectorSize + +int main() +{ + std::vector window_lengths = {5, 5, 5}; + std::vector window_strides = {2, 2, 2}; + std::vector window_dilations = {2, 2, 2}; + std::vector dinput_left_pads = {0, 0, 0}; + std::vector dinput_right_pads = {0, 0, 0}; + + ck::index_t N = 1; + ck::index_t C = 16; + ck::index_t Di = 40; + ck::index_t Hi = 40; + ck::index_t Wi = 40; + + pool3d_bwd_test( + true, + false, + N, + C, + Di, + Hi, + Wi, + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); +} diff --git a/example/51_avgpool3d_bwd/avgpool3d_bwd_fp32.cpp b/example/51_avgpool3d_bwd/avgpool3d_bwd_fp32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c2910c55679ca8fefce81cc0c07b7d6a439af182 --- /dev/null +++ b/example/51_avgpool3d_bwd/avgpool3d_bwd_fp32.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp" + +#include "avgpool3d_bwd_common.hpp" + +using DOutDataType = float; +using DInDataType = float; +using ComputeDataType = float; + +#if 1 +using DOutLayout = ck::tensor_layout::convolution::NDHWC; +using DInLayout = ck::tensor_layout::convolution::NDHWC; +#else +using DOutLayout = ck::tensor_layout::convolution::NCDHW; +using DInLayout = ck::tensor_layout::convolution::NCDHW; +#endif + +using DevicePoolBwdInstance = + ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC; // InSrcOutDstVectorSize + +int main() +{ + std::vector window_lengths = {5, 5, 5}; + std::vector window_strides = {2, 2, 2}; + std::vector window_dilations = {2, 2, 2}; + std::vector dinput_left_pads = {0, 0, 0}; + std::vector dinput_right_pads = {0, 0, 0}; + + ck::index_t N = 1; + ck::index_t C = 16; + ck::index_t Di = 40; + ck::index_t Hi = 40; + ck::index_t Wi = 40; + + pool3d_bwd_test( + true, + false, + N, + C, + Di, + Hi, + Wi, + window_lengths, + window_strides, + window_dilations, + dinput_left_pads, + dinput_right_pads); +} diff --git a/example/52_im2col_col2im/CMakeLists.txt b/example/52_im2col_col2im/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4dc6c8b4e0344f6d43e3d8d478fb07d4abcdcc13 --- /dev/null +++ b/example/52_im2col_col2im/CMakeLists.txt @@ -0,0 +1,15 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_im2col_col2im) + + add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp) + add_example_dependencies(example_im2col_col2im example_image_to_column_f32) + + add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp) + add_example_dependencies(example_im2col_col2im example_column_to_image_f32) + + set(target 1) + endif() +endforeach() diff --git a/example/52_im2col_col2im/column_to_image_f32.cpp b/example/52_im2col_col2im/column_to_image_f32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..047f2a21186be1223cfc7a7727c64045310be987 --- /dev/null +++ b/example/52_im2col_col2im/column_to_image_f32.cpp @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using InDataType = FP32; // ck::bhalf_t;//FP32; +using OutDataType = FP32; // ck::bhalf_t;//FP32; + +using ImLayout = ck::tensor_layout::convolution::GNHWC; +using ColumnToImageOp = ck::conv_tensor_rearrange_op::ColumnToImage; + +// clang-format off +using DeviceColToImgInstance = ck::tensor_operation::device::DeviceColumnToImageImpl + //#####################| Num| ImLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| + //#####################| Dim| | | | Size| Block| Block| Cluster| Per| + //#####################| Spatial| | | | | | | Lengths| Vector| + //#####################| | | | | | | | | | + < NDimSpatial, ImLayout, InDataType, OutDataType, 256, 128, 128, S<16, 16>, 1>; +// clang-format on + +bool RunColumnToImage(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params) +{ + const auto G = conv_params.G_; + const auto N = conv_params.N_; + const auto C = conv_params.C_; + + const ck::index_t NDoHoWo = + N * ck::accumulate_n( + conv_params.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const ck::index_t CZYX = + C * ck::accumulate_n( + conv_params.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + const auto in_desc = HostTensorDescriptor({G, NDoHoWo, CZYX}); + const auto out_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_params); + + std::array input_spatial_lengths{}; + std::array filter_spatial_lengths{}; + std::array output_spatial_lengths{}; + std::array image_g_n_c_wis_strides{}; + std::array gemm_g_m_k_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; + + copy(conv_params.input_spatial_lengths_, input_spatial_lengths); + copy(conv_params.filter_spatial_lengths_, filter_spatial_lengths); + copy(conv_params.output_spatial_lengths_, output_spatial_lengths); + copy(in_desc.GetStrides(), gemm_g_m_k_strides); + copy(out_desc.GetStrides(), image_g_n_c_wis_strides); + copy(conv_params.conv_filter_strides_, conv_filter_strides); + copy(conv_params.conv_filter_dilations_, conv_filter_dilations); + copy(conv_params.input_left_pads_, input_left_pads); + copy(conv_params.input_right_pads_, input_right_pads); + + Tensor in(in_desc); + Tensor out_device(out_desc); + Tensor out_host(out_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "out: " << out_device.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: in.GenerateTensorValue(GeneratorTensor_2{1, 2}); break; + default: in.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + + // reset input to zero + out_device_buf.SetZero(); + + static_assert(std::is_default_constructible_v); + + // do conv + auto col2img = DeviceColToImgInstance{}; + auto invoker = col2img.MakeInvoker(); + auto argument = col2img.MakeArgument(in_device_buf.GetDeviceBuffer(), + out_device_buf.GetDeviceBuffer(), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + if(!col2img.IsSupportedArgument(argument)) + { + std::cerr << "wrong! device_col2img with the specified compilation parameters does " + "not support this col2img problem" + << std::endl; + + return false; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + std::size_t num_btype = G * NDoHoWo * CZYX * (sizeof(OutDataType) + sizeof(InDataType)); + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + if(config.do_verification) + { + auto ref_column_to_image = ck::tensor_operation::host:: + ReferenceColumnToImage(); + + auto ref_invoker = ref_column_to_image.MakeInvoker(); + + auto ref_argument = ref_column_to_image.MakeArgument(in, + out_host, + conv_params.filter_spatial_lengths_, + conv_params.conv_filter_strides_, + conv_params.conv_filter_dilations_, + conv_params.input_left_pads_, + conv_params.input_right_pads_); + + if(!ref_column_to_image.IsSupportedArgument(&ref_argument)) + { + std::cerr << "wrong! ref_col2img with the specified compilation parameters does " + "not support this col2img problem" + << std::endl; + return false; + } + + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(out_device.mData.data()); + return ck::utils::check_err(out_device.mData, out_host.mData); + } + + return true; +} + +int RunColumnToImageExample(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_params = DefaultConvParams; + + if(!parse_cmd_args(argc, argv, config, conv_params)) + { + return EXIT_FAILURE; + } + + if(conv_params.num_dim_spatial_ != NDimSpatial) + { + std::cerr << "unsupported # of spatial dimensions" << std::endl; + return EXIT_FAILURE; + } + + return !RunColumnToImage(config, conv_params); +} + +int main(int argc, char* argv[]) { return RunColumnToImageExample(argc, argv); } diff --git a/example/52_im2col_col2im/common.hpp b/example/52_im2col_col2im/common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..61d30c4cb44abfcdbebffd88aebb27d9783f88c5 --- /dev/null +++ b/example/52_im2col_col2im/common.hpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_image_to_column.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_column_to_image.hpp" + +template +using S = ck::Sequence; + +static inline constexpr ck::index_t NDimSpatial = 2; + +using FP32 = float; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +#define DefaultConvParams \ + ck::utils::conv::ConvParam \ + { \ + NDimSpatial, 1, 32, 1, 1, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \ + } + +inline void print_help_msg() +{ + std::cerr << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +inline bool parse_cmd_args(int argc, + char* argv[], + ExecutionConfig& config, + ck::utils::conv::ConvParam& conv_params) +{ + constexpr int num_execution_config_args = + 3; // arguments for do_verification, init_method, time_kernel + constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ + + constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; + constexpr int threshold_to_catch_all_args = + threshold_to_catch_partial_args + num_conv_param_leading_args; + + if(argc == 1) + { + // use default + config = ExecutionConfig{}; + } + // catch only ExecutionConfig arguments + else if(argc == threshold_to_catch_partial_args) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + // catch both ExecutionConfig & ConvParam arguments + else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + conv_params = ck::utils::conv::parse_conv_param( + num_dim_spatial, threshold_to_catch_partial_args, argv); + } + else + { + print_help_msg(); + return false; + } + + return true; +} diff --git a/example/52_im2col_col2im/image_to_column_f32.cpp b/example/52_im2col_col2im/image_to_column_f32.cpp new file mode 100644 index 0000000000000000000000000000000000000000..140bdf8062d1ad91d02f84b4056d4b446011e632 --- /dev/null +++ b/example/52_im2col_col2im/image_to_column_f32.cpp @@ -0,0 +1,168 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +using InDataType = FP32; +using OutDataType = FP32; + +using ImLayout = ck::tensor_layout::convolution::GNHWC; +using ImageToColumnOp = ck::conv_tensor_rearrange_op::ImageToColumn; + +// clang-format off +using DeviceImgToColInstance = ck::tensor_operation::device::DeviceImageToColumnImpl + //#####################| Num| ImLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| + //#####################| Dim| | | | Size| Block| Block| Cluster| Per| + //#####################| Spatial| | | | | | | Lengths| Vector| + //#####################| | | | | | | | | | + < NDimSpatial, ImLayout, InDataType, OutDataType, 256, 128, 128, S<16, 16>, 1>; +// clang-format on + +bool RunImageToColumn(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_params) +{ + const auto G = conv_params.G_; + const auto N = conv_params.N_; + const auto C = conv_params.C_; + + const ck::index_t NDoHoWo = + N * ck::accumulate_n( + conv_params.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const ck::index_t CZYX = + C * ck::accumulate_n( + conv_params.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + const auto in_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_params); + const auto out_desc = HostTensorDescriptor({G, NDoHoWo, CZYX}); + + std::array input_spatial_lengths{}; + std::array filter_spatial_lengths{}; + std::array output_spatial_lengths{}; + std::array image_g_n_c_wis_strides{}; + std::array gemm_g_m_k_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; + + copy(conv_params.input_spatial_lengths_, input_spatial_lengths); + copy(conv_params.filter_spatial_lengths_, filter_spatial_lengths); + copy(conv_params.output_spatial_lengths_, output_spatial_lengths); + copy(in_desc.GetStrides(), image_g_n_c_wis_strides); + copy(out_desc.GetStrides(), gemm_g_m_k_strides); + copy(conv_params.conv_filter_strides_, conv_filter_strides); + copy(conv_params.conv_filter_dilations_, conv_filter_dilations); + copy(conv_params.input_left_pads_, input_left_pads); + copy(conv_params.input_right_pads_, input_right_pads); + + Tensor in(in_desc); + Tensor out_device(out_desc); + Tensor out_host(out_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "out: " << out_device.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; + default: in.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + + // reset input to zero + out_device_buf.SetZero(); + + static_assert(std::is_default_constructible_v); + + // do conv + auto img2col = DeviceImgToColInstance{}; + auto invoker = img2col.MakeInvoker(); + auto argument = img2col.MakeArgument(in_device_buf.GetDeviceBuffer(), + out_device_buf.GetDeviceBuffer(), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + if(!img2col.IsSupportedArgument(argument)) + { + std::cerr << "wrong! device_img2col with the specified compilation parameters does " + "not support this img2col problem" + << std::endl; + + return false; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + std::size_t num_btype = G * NDoHoWo * CZYX * (sizeof(OutDataType) + sizeof(InDataType)); + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + if(config.do_verification) + { + auto ref_image_to_column = ck::tensor_operation::host:: + ReferenceImageToColumn(); + + auto ref_invoker = ref_image_to_column.MakeInvoker(); + + auto ref_argument = ref_image_to_column.MakeArgument(in, + out_host, + conv_params.filter_spatial_lengths_, + conv_params.conv_filter_strides_, + conv_params.conv_filter_dilations_, + conv_params.input_left_pads_, + conv_params.input_right_pads_); + + if(!ref_image_to_column.IsSupportedArgument(&ref_argument)) + { + std::cerr << "wrong! ref_img2col with the specified compilation parameters does " + "not support this img2col problem" + << std::endl; + return false; + } + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device.mData, out_host.mData); + } + + return true; +} + +int RunImageToColumnExample(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_params = DefaultConvParams; + + if(!parse_cmd_args(argc, argv, config, conv_params)) + { + return EXIT_FAILURE; + } + + if(conv_params.num_dim_spatial_ != NDimSpatial) + { + std::cerr << "unsupported # of spatial dimensions" << std::endl; + return EXIT_FAILURE; + } + + return !RunImageToColumn(config, conv_params); +} + +int main(int argc, char* argv[]) { return RunImageToColumnExample(argc, argv); } diff --git a/example/60_gemm_multi_ABD/CMakeLists.txt b/example/60_gemm_multi_ABD/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..57bc0b33ef0aaf22cd9613e9612bf6a2ae593c51 --- /dev/null +++ b/example/60_gemm_multi_ABD/CMakeLists.txt @@ -0,0 +1,8 @@ +list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list2 AND target EQUAL 0) + add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13ba6398140793627420271257ae9ca061c9302b --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp @@ -0,0 +1,363 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +struct AddScale +{ + static constexpr auto I0 = ck::Number<0>{}; + static constexpr auto I1 = ck::Number<1>{}; + static constexpr auto I2 = ck::Number<2>{}; + static constexpr auto I3 = ck::Number<3>{}; + + __host__ __device__ constexpr void + operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const + { + const auto a0_v_t = ck::vector_type{a0}; + const auto a1_v_t = ck::vector_type{a1}; + + auto r_v_t = ck::vector_type{}; + + r_v_t.AsType()(I0) = + scale * (a0_v_t.AsType()[I0] + a1_v_t.AsType()[I0]); + r_v_t.AsType()(I1) = + scale * (a0_v_t.AsType()[I1] + a1_v_t.AsType()[I1]); + r_v_t.AsType()(I2) = + scale * (a0_v_t.AsType()[I2] + a1_v_t.AsType()[I2]); + r_v_t.AsType()(I3) = + scale * (a0_v_t.AsType()[I3] + a1_v_t.AsType()[I3]); + + a = r_v_t.AsType()[I0]; + } + + __host__ __device__ constexpr void + operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const + { + a = scale * (a0 + a1); + } + + // this attribute controls the copy_function applying element_wise_op with + // pack4_data + constexpr const static bool is_pack4_invocable = true; + + float scale = 1.0; +}; + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +using AElementOp = AddScale; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ELayout, + ck::Tuple, + ck::Tuple, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + float alpha = 1.0f; + float beta = 1.0f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{0.2}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}, + std::array{b_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA, StrideA}, + std::array{StrideB}, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor a_m_k({M, K}); + + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) + { + a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/61_contraction_multi_ABD/CMakeLists.txt b/example/61_contraction_multi_ABD/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..42500b64e69fb8345002498b281e7fe5f71fc575 --- /dev/null +++ b/example/61_contraction_multi_ABD/CMakeLists.txt @@ -0,0 +1,8 @@ +list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list2 AND target EQUAL 0) + add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6634dacd257281708be437e401a84e432b503601 --- /dev/null +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -0,0 +1,328 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using A0DataType = F16; +using A1DataType = F32; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +struct Multiply +{ + __host__ __device__ constexpr void + operator()(ck::half_t& a, const ck::half_t& a0, const float& a1) const + { + a = ck::type_convert(ck::type_convert(a0) * a1); + } +}; + +using AElementOp = Multiply; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultipleABD_Xdl_CShuffle< + NumDimM, + NumDimN, + NumDimK, + ck::Tuple, + ck::Tuple, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + float alpha = 1.0f; + float beta = 1.0f; + + // A0[M0, M1, K0, K1] + std::vector a0_ms_ks_lengths{30, 128, 32, 64}; + std::vector a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1}; + // A1[M1, K1] -> A1[M0, M1, K0, K1] + std::vector a1_ms_ks_lengths{30, 128, 32, 64}; + std::vector a1_ms_ks_strides{0, 64, 0, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; + std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_ms_ks.mData.data()); + a1_device_buf.ToDevice(a1_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + d_device_buf.ToDevice(d_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + std::array{a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}, + std::array{b_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + std::array, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths}, + std::array, 2>{a0_ms_ks_strides, a1_ms_ks_strides}, + std::array, 1>{b_ns_ks_lengths}, + std::array, 1>{b_ns_ks_strides}, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_contraction with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + { + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a0_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + } + + if(do_verification) + { + + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + + for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < a_ms_ks.mDesc.GetLengths()[1]; ++m1) + { + for(size_t k0 = 0; k0 < a_ms_ks.mDesc.GetLengths()[2]; ++k0) + { + for(size_t k1 = 0; k1 < a_ms_ks.mDesc.GetLengths()[3]; ++k1) + { + a_element_op(a_ms_ks(m0, m1, k0, k1), + a0_ms_ks(m0, m1, k0, k1), + a1_ms_ks(m0, m1, k0, k1)); + } + } + } + } + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor empty_tensor(std::vector{}, std::vector{}); + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, PassThrough{}, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1), + d_ms_ns(m0, m1, n0, n1)); + } + } + } + } + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/62_conv_fwd_activ/CMakeLists.txt b/example/62_conv_fwd_activ/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ea38216fa944859b379b68f324d8a1988a519d6e --- /dev/null +++ b/example/62_conv_fwd_activ/CMakeLists.txt @@ -0,0 +1,35 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_fwd_activ_xdl) + # Sigmoid + add_example_executable(example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_sigmoid_fp16) + # Tanh + add_example_executable(example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_tanh_fp16) + # Relu + add_example_executable(example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_relu_fp16) + # SoftRelu + add_example_executable(example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_softrelu_fp16) + # Abs + add_example_executable(example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_abs_fp16) + # Pow + add_example_executable(example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_pow_fp16) + # Clipped Relu + add_example_executable(example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_clippedrelu_fp16) + # Leaky Relu + add_example_executable(example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_leakyrelu_fp16) + # Elu + add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp) + add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_elu_fp16) + set(target 1) + endif() +endforeach() diff --git a/example/62_conv_fwd_activ/convnd_fwd_activ_common.hpp b/example/62_conv_fwd_activ/convnd_fwd_activ_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..185026b1e3b70f40876393571967833e12417ab9 --- /dev/null +++ b/example/62_conv_fwd_activ/convnd_fwd_activ_common.hpp @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +template +bool run_grouped_conv_fwd(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_abs_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_abs_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4fe0c857fa4f75e52b97762cce85b7b4ccafec74 --- /dev/null +++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_abs_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::UnaryAbs; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; +#include "run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_clippedrelu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_clippedrelu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..feabacc5c963de61589ee477e7c0563c2aaed0a1 --- /dev/null +++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_clippedrelu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::ClippedRelu; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; +#include "run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_elu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_elu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..793102dbc60346c011c034800e357bade9116803 --- /dev/null +++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_elu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Elu; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; +#include "run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_leakyrelu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_leakyrelu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a77408db7e916cc6fdd8d873cc43b663e7449007 --- /dev/null +++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_leakyrelu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::LeakyRelu; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; +#include "run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_pow_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_pow_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2b695cf8c3d46000ddd1f6c2c22582d2fa50ce24 --- /dev/null +++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_pow_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Power; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; +#include "run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_relu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_relu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e1b6e3f0ccbcccd5b8cb2ad498dd11f6f1ddf025 --- /dev/null +++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_relu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Relu; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; +#include "run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_sigmoid_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_sigmoid_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..350c15a787aee8f8e4374872bb5f6ca276d7a01f --- /dev/null +++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_sigmoid_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Sigmoid; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; +#include "run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_softrelu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_softrelu_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ec52e1a3c4e4ad3f9b289d6b23c59f109d5e325d --- /dev/null +++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_softrelu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::SoftRelu; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; +#include "run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_tanh_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_tanh_fp16.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dca405669a79eb70f0f1e42113f6dcfee3cc034e --- /dev/null +++ b/example/62_conv_fwd_activ/convnd_fwd_xdl_tanh_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::TanH; + +using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; +#include "run_convnd_fwd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc b/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc new file mode 100644 index 0000000000000000000000000000000000000000..7c20c010662510ba2129f4177376fffe745807a8 --- /dev/null +++ b/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +void print_helper_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +bool run_convnd_fwd_example(int argc, char* argv[]) +{ + print_helper_msg(); + + bool do_verification = true; + // Use floats for SoftRelu by default to avoid overflow after e^x. + int init_method = + std::is_same_v ? 2 : 1; + bool time_kernel = false; + + // Following shapes are selected to avoid overflow. Expect inf in case of + // size increase for some elementwise ops. + ck::utils::conv::ConvParam conv_param{ + 3, 1, 16, 128, 8, {3, 3, 3}, {17, 17, 17}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}; + + if(argc == 1) + { + // use default + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + + conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv); + } + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + const auto run = [&]() { + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + return run_grouped_conv_fwd(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); + }; + + if(conv_param.num_dim_spatial_ == 3) + { + return run(); + } + + return false; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 1fdd2f6d14826776b734b76815e192b4654e5b09..c19ba93b69ff635daea2b8d8008971b1b03a0402 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -7,20 +7,112 @@ add_custom_target(examples) function(add_example_executable EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") - add_executable(${EXAMPLE_NAME} ${FILE_NAME}) - target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) - add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) - add_dependencies(examples ${EXAMPLE_NAME}) - add_dependencies(check ${EXAMPLE_NAME}) - rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 1) + if(DEFINED DTYPES) + foreach(source IN LISTS FILE_NAME) + set(test 0) + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message("removing example source file ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + endif() + foreach(source IN LISTS FILE_NAME) + if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + message("removing dl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #only continue if there are some source files left on the list + if(FILE_NAME) + add_executable(${EXAMPLE_NAME} ${FILE_NAME}) + target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) + add_dependencies(examples ${EXAMPLE_NAME}) + add_dependencies(check ${EXAMPLE_NAME}) + rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 0) + endif() + #message("add_example returns ${result}") + set(result ${result} PARENT_SCOPE) endfunction(add_example_executable EXAMPLE_NAME) +function(add_example_dependencies EXAMPLE_NAME FILE_NAME) + if(result EQUAL 0) + add_dependencies(${EXAMPLE_NAME} ${FILE_NAME}) + endif() +endfunction(add_example_dependencies EXAMPLE_NAME) + function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") - add_executable(${EXAMPLE_NAME} ${FILE_NAME}) - target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) - add_dependencies(examples ${EXAMPLE_NAME}) - rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 1) + if(DEFINED DTYPES) + foreach(source IN LISTS FILE_NAME) + set(test 0) + if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message("removing example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + endif() + foreach(source IN LISTS FILE_NAME) + if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + message("removing dl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #only continue if there are some source files left on the list + if(FILE_NAME) + add_executable(${EXAMPLE_NAME} ${FILE_NAME}) + target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + add_dependencies(examples ${EXAMPLE_NAME}) + rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 0) + endif() + #message("add_example returns ${result}") + set(result ${result} PARENT_SCOPE) endfunction(add_example_executable_no_testing EXAMPLE_NAME) # add all example subdir diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index cb20ea2492d53a8b9ca8e6756584cec4539767d1..1e41404192a3021410f76d07d6eb7aab45be8233 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -1,8 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/config.h" + #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" @@ -27,11 +29,27 @@ #define CK_WAVELET_MIN_BLOCK_PER_CU 2 #endif +// kernel attribute: amdgpu_waves_per_eu() +#ifdef CK_USE_WAVES_PER_EU +// for 1-wave kernels, control arguments of amdgpu_waves_per_eu() attribute +#ifndef CK_MIN_WAVES_PER_EU +#define CK_MIN_WAVES_PER_EU 0 +#endif + +#ifndef CK_MAX_WAVES_PER_EU +#define CK_MAX_WAVES_PER_EU 0 +#endif + +#else +#define CK_USE_WAVES_PER_EU 0 +#endif + // buffer resource #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_BUFFER_RESOURCE_3RD_DWORD -1 #elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx940__) // for GPU code + defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) // for GPU code #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx1030__) // for GPU code #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 @@ -44,24 +62,29 @@ #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code #define CK_USE_AMD_V_MAC_F32 #elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__) || \ - defined(__gfx940__) // for GPU code + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // for GPU code #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 +#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) +#define CK_USE_AMD_V_FMAC_F32 +#define CK_USE_AMD_V_DOT2_F32_F16 +#define CK_USE_AMD_V_DOT4_I32_I8_GFX11 #endif // MFMA instruction #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_MFMA -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) // for GPU code +#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) // for GPU code #define CK_USE_AMD_MFMA #endif -#if(defined(__gfx90a__) || defined(__gfx940__)) +#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) #define CK_USE_AMD_MFMA_BF16_1K_OP #endif -#if defined(__gfx940__) +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define CK_USE_AMD_MFMA_GFX940 #endif @@ -84,13 +107,15 @@ // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) // for GPU code +#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 #endif -#if(defined(__gfx90a__) || defined(__gfx940__)) // for GPU code +#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__)) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 #else #define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 @@ -99,8 +124,15 @@ // inline asm #define CK_USE_AMD_INLINE_ASM 1 -// inner product (DLOP) -#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1 +// inner product (V_MAC/V_FMAC) +#define CK_USE_AMD_V_MAC_INLINE_ASM 1 + +// V_DOT inline instructions, less efficient since they require adding +// `s_nop`s to avoid hazard +#define CK_USE_AMD_V_DOT_INLINE_ASM 0 + +// inner product using V_DOT with DPP8 modifiers +#define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1 // block synchronization only s_wait lgkmcnt(0), not vmcnt(0) #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 @@ -144,6 +176,10 @@ #define CK_EXPERIMENTAL_INTER_WAVE_INSTANCES 1 // experimental feature: add instances using pipeline v2 #define CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES 1 +// experimental feature: optimize pipeline v2 by IGLP strategy (value=ID of strategy) +#ifndef CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT +#define CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT 0 +#endif // hack: have underlying assumption that need to be satsified, otherwise it's a bug // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be @@ -169,13 +205,17 @@ // workaround: compiler issue on gfx908 #define CK_WORKAROUND_SWDEV_388832 1 + // flag to enable (1) or disable (0) the debugging output in some kernels #define DEBUG_LOG 0 // denorm test fix, required to work around dissue #ifndef CK_WORKAROUND_DENORM_FIX #define CK_WORKAROUND_DENORM_FIX 0 -#endif +#elif +// enable only on MI200 +#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) +#endif // CK_WORKAROUND_DENORM_FIX namespace ck { diff --git a/include/ck/config.h.in b/include/ck/config.h.in new file mode 100644 index 0000000000000000000000000000000000000000..1748344756771a2d57dfd044d79f8c72f26f8fb4 --- /dev/null +++ b/include/ck/config.h.in @@ -0,0 +1,109 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_CONFIG_H_IN +#define CK_CONFIG_H_IN + +// clang-format off +// +// DataType supports in the current CK build +// +#ifndef DTYPES +#cmakedefine DTYPES "@DTYPES@" +#endif +// if DTYPES is not defined, enable all datatypes in headerfiles +#ifndef CK_ENABLE_ALL_DTYPES +#cmakedefine CK_ENABLE_ALL_DTYPES @CK_ENABLE_ALL_DTYPES@ +#if defined(CK_ENABLE_ALL_DTYPES) +#ifndef CK_ENABLE_INT8 +#define CK_ENABLE_INT8 "ON" +#endif +#ifndef CK_ENABLE_FP8 +#define CK_ENABLE_FP8 "ON" +#endif +#ifndef CK_ENABLE_BF8 +#define CK_ENABLE_BF8 "ON" +#endif +#ifndef CK_ENABLE_FP16 +#define CK_ENABLE_FP16 "ON" +#endif +#ifndef CK_ENABLE_BF16 +#define CK_ENABLE_BF16 "ON" +#endif +#ifndef CK_ENABLE_FP32 +#define CK_ENABLE_FP32 "ON" +#endif +#ifndef CK_ENABLE_FP64 +#define CK_ENABLE_FP64 "ON" +#endif +#endif +#endif +// if DTYPES are selectively enabled +#ifndef CK_ENABLE_INT8 +#cmakedefine CK_ENABLE_INT8 @CK_ENABLE_INT8@ +#endif + +#ifndef CK_ENABLE_FP8 +#cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@ +#endif + +#ifndef CK_ENABLE_BF8 +#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@ +#endif + +#ifndef CK_ENABLE_FP16 +#cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@ +#endif + +#ifndef CK_ENABLE_BF16 +#cmakedefine CK_ENABLE_BF16 @CK_ENABLE_BF16@ +#endif + +#ifndef CK_ENABLE_FP32 +#cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@ +#endif + +#ifndef CK_ENABLE_FP64 +#cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@ +#endif + +// +// Legacy DL kernel supports in the current CK build +// by default DL kernels are turned OFF +// +#ifndef CK_ENABLE_DL_KERNELS +#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@ +#endif + +// +// Instances supports in the current CK build +// +#ifndef CK_ENABLE_INSTANCES_ONLY +#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@ +#endif + +// clang-format on + +#endif // CK_CONFIG_H_IN diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index e2cbdb733272d37aa5dbc7a746e86911d6b8644f..be1dbc1657e5bf2ac15d37184175c16b79f0b8eb 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -51,4 +51,11 @@ inline std::string get_device_name() return name; } +inline bool is_xdl_supported() +{ + return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942"; +} + } // namespace ck diff --git a/include/ck/host_utility/hip_check_error.hpp b/include/ck/host_utility/hip_check_error.hpp index d3dc8eaf1eb8b87207256ea4a521e23d5a49ca9c..3e44faecb692b4d509d654379294277eb99d6473 100644 --- a/include/ck/host_utility/hip_check_error.hpp +++ b/include/ck/host_utility/hip_check_error.hpp @@ -1,10 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include #include +// To be removed, which really does not tell the location of failed HIP functional call inline void hip_check_error(hipError_t x) { if(x != hipSuccess) @@ -15,3 +17,16 @@ inline void hip_check_error(hipError_t x) throw std::runtime_error(ss.str()); } } + +#define HIP_CHECK_ERROR(retval_or_funcall) \ + do \ + { \ + hipError_t _tmpVal = retval_or_funcall; \ + if(_tmpVal != hipSuccess) \ + { \ + std::ostringstream ostr; \ + ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \ + << hipGetErrorString(_tmpVal); \ + throw std::runtime_error(ostr.str()); \ + } \ + } while(0) diff --git a/include/ck/host_utility/io.hpp b/include/ck/host_utility/io.hpp index ac8719592db8b48996bfbc65dc656d6d96bde545..55734bab2e469a12cd02c417cbb25d9c8a49727c 100644 --- a/include/ck/host_utility/io.hpp +++ b/include/ck/host_utility/io.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index 24f2121674c7d371a7d13c0b1c83e7f0a5e059be..df311620e1da358c4409548778dcb0cdb131fb5d 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -34,6 +34,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, #endif // warm up kernel<<>>(args...); + hip_check_error(hipGetLastError()); const int nrepeat = 10; #if DEBUG_LOG @@ -50,6 +51,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, for(int i = 0; i < nrepeat; ++i) { kernel<<>>(args...); + hip_check_error(hipGetLastError()); } hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); @@ -64,11 +66,86 @@ float launch_and_time_kernel(const StreamConfig& stream_config, else { kernel<<>>(args...); + hip_check_error(hipGetLastError()); return 0; } #else kernel<<>>(args...); + hip_check_error(hipGetLastError()); + + return 0; +#endif +} + +template +float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, + PreProcessFunc preprocess, + F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + Args... args) +{ +#if CK_TIME_KERNEL + if(stream_config.time_kernel_) + { +#if DEBUG_LOG + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up 1 time\n"); +#endif + // warm up + preprocess(); + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + + const int nrepeat = 10; +#if DEBUG_LOG + printf("Start running %d times...\n", nrepeat); +#endif + hipEvent_t start, stop; + + hip_check_error(hipEventCreate(&start)); + hip_check_error(hipEventCreate(&stop)); + + hip_check_error(hipDeviceSynchronize()); + hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + + for(int i = 0; i < nrepeat; ++i) + { + preprocess(); + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + } + + hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); + hip_check_error(hipEventSynchronize(stop)); + + float total_time = 0; + + hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + + return total_time / nrepeat; + } + else + { + preprocess(); + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + + return 0; + } +#else + kernel<<>>(args...); + hip_check_error(hipGetLastError()); return 0; #endif diff --git a/include/ck/host_utility/stream_utility.hpp b/include/ck/host_utility/stream_utility.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9ab49489bbe59b5152ad61c74622be48c33c28a7 --- /dev/null +++ b/include/ck/host_utility/stream_utility.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/stream_config.hpp" +#include "ck/host_utility/hip_check_error.hpp" + +static inline int getAvailableComputeUnitCount(const StreamConfig& stream_config) +{ + constexpr int MAX_MASK_DWORDS = 64; + + // assume at most 64*32 = 2048 CUs + uint32_t cuMask[MAX_MASK_DWORDS]; + + for(int i = 0; i < MAX_MASK_DWORDS; i++) + cuMask[i] = 0; + + auto countSetBits = [](uint32_t dword) { + int count = 0; + + while(dword != 0) + { + if(dword & 0x1) + count++; + + dword = dword >> 1; + }; + + return (count); + }; + + hip_check_error(hipExtStreamGetCUMask(stream_config.stream_id_, MAX_MASK_DWORDS, &cuMask[0])); + + int ret = 0; + + for(int i = 0; i < MAX_MASK_DWORDS; i++) + ret += countSetBits(cuMask[i]); + + return (ret); +}; diff --git a/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index db8e48df6d4b18e9c8a60b8f3d4382d1c6e649f8..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,275 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// Number of GEMMs = YTilde * XTilde -// GemmM = C -// GemmN = N * HTildeSlice * WTildeSlice -// GemmK = K * YDotSlice * XDotSlice -template -__host__ __device__ constexpr auto -transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number, - Number, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - constexpr auto IYTilde = Number{}; - constexpr auto IXTilde = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - const auto YTilde = ConvStrideH / GcdStrideDilationH; - const auto XTilde = ConvStrideW / GcdStrideDilationW; - - const auto YDot = math::integer_divide_ceil(Y, YTilde); - const auto XDot = math::integer_divide_ceil(X, XTilde); - - const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); - const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); - - // only work on HTilde and WTilde that contribute to non-padding area of input tensor - const auto IHTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); - const auto IWTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); - - const auto IHTildeSliceEnd = - math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); - const auto IWTildeSliceEnd = - math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); - - const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; - const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; - - // GemmK is different for each GEMM - const auto YDotSlice = math::integer_divide_ceil(Y - IYTilde, YTilde); - const auto XDotSlice = math::integer_divide_ceil(X - IXTilde, XTilde); - - const auto K1 = GemmK1; - const auto K0 = K / K1; - - // weight tensor - const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_k_y_x_c_grid_desc, - make_tuple(make_pass_through_transform(K), - make_embed_transform(make_tuple(YDot, YTilde), - make_tuple(ConvStrideH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, XTilde), - make_tuple(ConvStrideW / GcdStrideDilationW, I1)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1)), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(IYTilde), - make_freeze_transform(IXTilde), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<3>{}, - Sequence<2>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0, 1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<>{}, - Sequence<>{}, - Sequence<4>{})); - -#if 1 - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - wei_k0_k1_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), - make_pass_through_transform(C), - make_pass_through_transform(K1)), - make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#else - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - wei_k0_k1_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), - make_pass_through_transform(C), - make_pass_through_transform(K1)), - make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#endif - - // output tensor - // this add padding check - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_n_ho_wo_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Ho, I0, I0), - make_pad_transform(Wo, I0, I0), - make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YDot, HTilde), - make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, WTilde), - make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), - make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_unmerge_transform(make_tuple(K0, K1))), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5, 6>{})); - -#if 1 - const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), - make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(K1)), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#else - const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, - make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), - make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(K1)), - make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#endif - - // input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YTilde, HTilde), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(XTilde, WTilde), - make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_freeze_transform(IYTilde), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(IXTilde), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<>{}, - Sequence<1>{}, - Sequence<>{}, - Sequence<2>{}, - Sequence<3>{})); - - const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_pass_through_transform(C), - make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))), - make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, - out_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 5391b595b5c1e9b5b7395d1a676efcd87b9f2e6e..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,355 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// A: out -// B: wei -// C: in -// Number of GEMMs = YTilde * XTilde -// GemmM = N * HTildeSlice * WTildeSlice -// GemmN = C -// GemmK = K * YDotSlice * XDotSlice -template -__host__ __device__ constexpr auto -transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - IYTilde i_ytilde, - IXTilde i_xtilde, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - const auto YTilde = ConvStrideH / GcdStrideDilationH; - const auto XTilde = ConvStrideW / GcdStrideDilationW; - - const auto YDot = math::integer_divide_ceil(Y, YTilde); - const auto XDot = math::integer_divide_ceil(X, XTilde); - - const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); - const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); - - // only work on HTilde and WTilde that contribute to non-padding area of input tensor - const auto IHTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); - const auto IWTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); - - const auto IHTildeSliceEnd = - math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); - const auto IWTildeSliceEnd = - math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); - - const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; - const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; - - // GemmK is different for each GEMM - const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); - const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); - - const auto K1 = GemmK1; - const auto K0 = K / K1; - - // A: output tensor - // this add padding check - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_n_ho_wo_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Ho, I0, I0), - make_pad_transform(Wo, I0, I0), - make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YDot, HTilde), - make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, WTilde), - make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), - make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_unmerge_transform(make_tuple(K0, K1))), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5, 6>{})); - -#if 1 - const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), - make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(K1)), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#else - const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, - make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), - make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(K1)), - make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#endif - - // B: weight tensor - const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_k_y_x_c_grid_desc, - make_tuple(make_pass_through_transform(K), - make_embed_transform(make_tuple(YDot, YTilde), - make_tuple(ConvStrideH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, XTilde), - make_tuple(ConvStrideW / GcdStrideDilationW, I1)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1)), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(i_ytilde), - make_freeze_transform(i_xtilde), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<3>{}, - Sequence<2>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0, 1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<>{}, - Sequence<>{}, - Sequence<4>{})); - -#if 1 - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - wei_k0_k1_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), - make_pass_through_transform(C), - make_pass_through_transform(K1)), - make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#else - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - wei_k0_k1_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), - make_pass_through_transform(C), - make_pass_through_transform(K1)), - make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); -#endif - - // C: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YTilde, HTilde), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(XTilde, WTilde), - make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_freeze_transform(i_ytilde), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(i_xtilde), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<>{}, - Sequence<1>{}, - Sequence<>{}, - Sequence<2>{}, - Sequence<3>{})); - - const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc); -} - -// A: out -// B: wei -// C: in -// Number of GEMMs = 1 -// GemmM = N * Ho * Wo -// GemmN = C -// GemmK = K -template -__host__ __device__ constexpr auto -transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1( - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const TensorDescriptor& /* wei_k_y_x_c_grid_desc */, - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const ConvStrides& conv_strides, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto K1 = GemmK1; - const auto K0 = K / K1; - - // A: output tensor - const auto out_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), - make_unmerge_transform(make_tuple(K0, K1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{})); - - // B: weight tensor - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C)), - make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // C: input tensor - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), - make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_freeze_transform(I0), - make_freeze_transform(I0), - make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), - make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); - - return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp deleted file mode 100644 index bb1dc239f4d7e6a5abade2dda4a17a98ab7d1a47..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,150 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP -#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM = K -// GemmK = N * Ho * Wo -// GemmN = C * Y * X -template -__host__ __device__ constexpr auto -transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad( - const TensorDescriptor& wei_k_c_y_x_grid_desc, - const TensorDescriptor& in_n_c_hi_wi_grid_desc, - const TensorDescriptor& out_n_k_ho_wo_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number, - GemmKBatchType GemmKBatch, - GemmKPadType GemmKPad) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); - - const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); - const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = K; - const auto GemmN = C * Y * X; - const auto GemmKTotal = N * Ho * Wo; - const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); - - // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( - in_n_c_hi_wi_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( - in_n_c_hip_wip_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - const auto in_gemmktotal_gemmn_grid_desc = - transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp deleted file mode 100644 index ca530934e49dcccea388fd8501e5bf046f1c94af..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,132 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP -#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM = K -// GemmK = N * Ho * Wo -// GemmN = C * Y * X -template -__host__ __device__ constexpr auto -transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( - const TensorDescriptor& wei_k_c_y_x_grid_desc, - const TensorDescriptor& in_n_c_hi_wi_grid_desc, - const TensorDescriptor& out_n_k_ho_wo_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); - - const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); - const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = K; - const auto GemmN = C * Y * X; - const auto GemmK = N * Ho * Wo; - const auto GemmK0 = GemmK / GemmK1; - - // weight tensor - const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // input tensor - const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( - in_n_c_hi_wi_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( - in_n_c_hip_wip_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - const auto in_gemmk_gemmn_grid_desc = - transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmk0_gemmn_gemmk1_grid_desc = - transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto out_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(out_gemmk_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index e960f90c4bb1974ddfaaf9faf6a12d75d1db2f4d..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,150 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// A: in -// B: wei -// C: out -// GemmM = N * Ho * Wo -// GemmN = K -// GemmK = Y * X * C -template -__host__ __device__ constexpr auto -transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad( - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number, - GemmKBatchType GemmKBatch, - GemmKPadType GemmKPad) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = Y * X * C; - const auto GemmN = K; - const auto GemmKTotal = N * Ho * Wo; - const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); - - // A: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmktotal_gemmm_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: output tensor - const auto out_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - - const auto out_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 052bab423db04740c88721026a3cceb3fa77c292..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,135 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// A: in -// B: wei -// C: out -// GemmM = N * Ho * Wo -// GemmN = K -// GemmK = Y * X * C -template -__host__ __device__ constexpr auto -transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = Y * X * C; - const auto GemmN = K; - const auto GemmK = N * Ho * Wo; - const auto GemmK0 = GemmK / GemmK1; - - // A: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmk_gemmm_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(in_gemmk_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // B: output tensor - const auto out_gemmk_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmk0_gemmn_gemmk1_grid_desc = - transform_tensor_descriptor(out_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, - out_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index c301a9e0c67b11c42fbe2a0a99c4f764049a5222..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,147 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// A: out -// B: in -// C: wei -// GemmM = K -// GemmN = Y * X * C -// GemmKTotal = N * Ho * Wo -template -__host__ __device__ constexpr auto -transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad( - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number, - GemmKBatchType GemmKBatch, - GemmKPadType GemmKPad) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = K; - const auto GemmN = Y * X * C; - const auto GemmKTotal = N * Ho * Wo; - const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); - - // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // B: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmktotal_gemmn_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp b/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp index 41267536551ea298e111f6d93b8e6f3f8f9ed475..6b118e972e67897f2c9e0cad3a8f959760f82f6c 100644 --- a/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP #define CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 381f9ac9d6f5a9b000eb44839c7198e4dc9f5d9b..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,260 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template -__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad( - const TensorDescriptor& wei_k_c_y_x_global_desc, - const TensorDescriptor& in_n_c_hi_wi_global_desc, - const TensorDescriptor& out_n_k_ho_wo_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); - - const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); - const auto X = wei_k_c_y_x_global_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // input tensor - const auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - const auto in_gemmk_gemmn_global_desc = - transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple( - wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); -} - -template -__host__ __device__ constexpr auto -transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad( - const TensorDescriptor& wei_k_c_y_x_global_desc, - const TensorDescriptor& in_n_c_hi_wi_global_desc, - const TensorDescriptor& out_n_k_ho_wo_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); - - const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); - const auto X = wei_k_c_y_x_global_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - assert(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0); - - // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // input tensor - const auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - const auto in_gemmk_gemmn_global_desc = - transform_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple( - wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); -} - -template -__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1( - const TensorDescriptor& wei_k_c_y_x_global_desc, - const TensorDescriptor& in_n_c_hi_wi_global_desc, - const TensorDescriptor& out_n_k_ho_wo_global_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); - - const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); - const auto X = wei_k_c_y_x_global_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && - ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && - InRightPadW == 0); - - // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // input tensor - const auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple( - wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index ebfaabb03ebdda41d086e4b80c5856156810992a..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,179 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template -__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad( - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - // weight tensor - const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmk_gemmn_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - return make_tuple( - wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc); -} - -template -__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1( - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && - ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && - InRightPadW == 0); - - // weight tensor - const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // input tensor - const auto in_gemmk_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - // output tensor - const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - return make_tuple( - wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 6e576d69f5f4b5a9a60a05f84edb43fec990e91a..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,132 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template -__host__ __device__ constexpr auto -transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( - const TensorDescriptor& wei_k_c_y_x_grid_desc, - const TensorDescriptor& in_n_c_hi_wi_grid_desc, - const TensorDescriptor& out_n_k_ho_wo_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); - - const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); - const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = K; - const auto GemmN = N * Ho * Wo; - const auto GemmK = C * Y * X; - const auto GemmK0 = GemmK / GemmK1; - - // weight tensor - const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // input tensor - const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( - in_n_c_hi_wi_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( - in_n_c_hip_wip_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - const auto in_gemmk_gemmn_grid_desc = - transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmk0_gemmn_gemmk1_grid_desc = - transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), - make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 13e1bf251abfdc3716bdd17f331f82fd7bdfad41..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,132 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template -__host__ __device__ constexpr auto -transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = K; - const auto GemmN = N * Ho * Wo; - const auto GemmK = C * Y * X; - const auto GemmK0 = GemmK / GemmK1; - - // weight tensor - const auto wei_gemmk_gemmm_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(wei_gemmk_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmk_gemmn_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmk0_gemmn_gemmk1_grid_desc = - transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // output tensor - const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 088d14b2ee472bc049385ad27e45d02df1e3b6c9..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,134 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// A: in -// B: wei -// C: out -// GemmM = N * Ho * Wo -// GemmN = K -// GemmK = Y * X * C -template -__host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk( - const TensorDescriptor& in_n_hi_wi_c_grid_desc, - const TensorDescriptor& wei_k_y_x_c_grid_desc, - const TensorDescriptor& out_n_ho_wo_k_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - Number) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto GemmK1 = Number{}; - - const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - - const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - - const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - - const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); - const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto GemmM = N * Ho * Wo; - const auto GemmN = K; - const auto GemmK = Y * X * C; - const auto GemmK0 = GemmK / GemmK1; - - // A: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmk_gemmm_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmk0_gemmm_gemmk1_grid_desc = - transform_tensor_descriptor(in_gemmk_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // B: weight tensor - const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = - transform_tensor_descriptor(wei_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // C: output tensor - const auto out_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp b/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp deleted file mode 100644 index a6785d56df7e70fdb4882f0c431b886d8c87caac..0000000000000000000000000000000000000000 --- a/include/ck/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,135 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM0 = 1 -// GemmM1 = K -// GemmN0 = N0 -// GemmN1 = (N / N0) * Ho * Wo -// GemmK0 = (C / C0) * Y * X -// GemmK1 = C0 -template -__host__ __device__ constexpr auto -transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( - const TensorDescriptor& wei_k_c_y_x_grid_desc, - const TensorDescriptor& in_n_c_hi_wi_grid_desc, - const TensorDescriptor& out_n_k_ho_wo_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const N0Type& N0, - const C0Type& C0) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); - - const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); - const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - const auto N1 = N / N0; - const auto C1 = C / C0; - - // weight tensor - const auto wei_gk0_gm0_gm1_gk1_grid_desc = - transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), - make_tuple(make_unmerge_transform(make_tuple(I1, K)), - make_unmerge_transform(make_tuple(C0, C1 * Y * X))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{})); - - // input tensor - const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( - in_n_c_hi_wi_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_tensor_descriptor( - in_n_c_hip_wip_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(N0, N1)), - make_unmerge_transform(make_tuple(C0, C1)), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{})); - - const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_tensor_descriptor( - in_n0_n1_c0_c1_y_ho_x_wo_grid_desc, - make_tuple(make_merge_transform(make_tuple(C1, Y, X)), - make_pass_through_transform(N0), - make_merge_transform(make_tuple(N1, Ho, Wo)), - make_pass_through_transform(C0)), - make_tuple(Sequence<3, 4, 6>{}, Sequence<0>{}, Sequence<1, 5, 7>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - // output tensor - const auto out_n_k_howo_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)); - - const auto out_n0_n1_1_k_howo_grid_desc = - transform_tensor_descriptor(out_n_k_howo_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(N0, N1)), - make_unmerge_transform(make_tuple(I1, K)), - make_pass_through_transform(Ho * Wo)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{})); - - const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_tensor_descriptor( - out_n0_n1_1_k_howo_grid_desc, - make_tuple(make_pass_through_transform(I1), - make_pass_through_transform(K), - make_pass_through_transform(N0), - make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))), - make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - return make_tuple( - wei_gk0_gm0_gm1_gk1_grid_desc, in_gk0_gn0_gn1_gk1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc); -} - -} // namespace ck -#endif diff --git a/include/ck/stream_config.hpp b/include/ck/stream_config.hpp index 70ca34555a01436c79ba244ac03572bb4e9520b4..505a602b240428bcd1f0f81018fef0c4716c80b7 100644 --- a/include/ck/stream_config.hpp +++ b/include/ck/stream_config.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor/static_tensor.hpp b/include/ck/tensor/static_tensor.hpp index fee679f91060aca623ab498362de78fd8118fee2..d719ef9760d79297600d7524167eba78cd137831 100644 --- a/include/ck/tensor/static_tensor.hpp +++ b/include/ck/tensor/static_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_STATIC_TENSOR_HPP #define CK_STATIC_TENSOR_HPP diff --git a/include/ck/tensor_description/cluster_descriptor.hpp b/include/ck/tensor_description/cluster_descriptor.hpp index 0c9ea2ff2a0d73b793008a954eaf7293b33ade08..2dfcad8e042e548d15b4bd963fe615e39c04eff1 100644 --- a/include/ck/tensor_description/cluster_descriptor.hpp +++ b/include/ck/tensor_description/cluster_descriptor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index 4e4d7593e9083f321f1cabe65faaee3c14259666..ae3139ce78c8b3b881ee36602011dd895532bd83 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -1042,13 +1042,13 @@ struct Merge_v2_magic_division using UpLengths = decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); - using LowLengthsMagicDivisorMultipiler = decltype( - generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, - Number{})); + using LowLengthsMagicDivisorMultipiler = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, + Number{})); - using LowLengthsMagicDivisorShift = decltype( - generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift{}, - Number{})); + using LowLengthsMagicDivisorShift = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_shift{}, + Number{})); LowLengths low_lengths_; LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_; @@ -1201,9 +1201,9 @@ struct Merge_v2r2_magic_division lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, Number{})); - using LowLengthsScanMagicDivisorShift = decltype( - generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift{}, - Number{})); + using LowLengthsScanMagicDivisorShift = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_shift{}, + Number{})); LowLengths low_lengths_; LowLengthsScan low_lengths_scan_; diff --git a/include/ck/tensor_description/multi_index_transform_helper.hpp b/include/ck/tensor_description/multi_index_transform_helper.hpp index 044a90370095eb53b94b5c2fba81abdbeae82c00..af0a8a34d0e48bde494dacfff5bbb17eda5eb479 100644 --- a/include/ck/tensor_description/multi_index_transform_helper.hpp +++ b/include/ck/tensor_description/multi_index_transform_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index d42e0a6ff08f60eb1cf6e0f241f93e528bf3514b..3ffac32469a8d97d678bc6b5fe62a7cc5e0b24ab 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index f07d5b1733d9cc96dfac9f18cbbe30510760cabc..f1df2eedd466c81e4b7938c4808075869e6309fd 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/tensor_descriptor_helper.hpp b/include/ck/tensor_description/tensor_descriptor_helper.hpp index 461aae72cf7b1879c90e473adb3814e1c4875b52..f3ac041bf9b8dc03802d170fbd7bf08ce6ab9cb1 100644 --- a/include/ck/tensor_description/tensor_descriptor_helper.hpp +++ b/include/ck/tensor_description/tensor_descriptor_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_description/tensor_space_filling_curve.hpp b/include/ck/tensor_description/tensor_space_filling_curve.hpp index 17c9100b9fd76418d562c5c175ad24062d8c3415..9a326092d2e0fd8392ec42a8c0a82b4167076373 100644 --- a/include/ck/tensor_description/tensor_space_filling_curve.hpp +++ b/include/ck/tensor_description/tensor_space_filling_curve.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp index 8b1b7be11ef7b3cd19f11bd400b8a96146921bf7..f23404a1d7ba81f9c65bf111da51701f2814d4bc 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -11,7 +11,7 @@ namespace ck { // C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1] -// A and B are visable to the whole block, C is distributed among each thread +// A and B are visible to the whole block, C is distributed among each thread // Assume: // 1. A: // 1. ABlockDesc_BK0_BM_BK1 is known at compile-time diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp index 33120bd86ff01da47f7b80c593bc966785cb3711..b0143366c1de5f8e6ec52cceefd3c20e4e77cbb7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP #define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp index f45655721fe4de11b50b91e0dc5d22790ff73bea..0d092da5168c5d5629eecc2aac7cbdd53e210277 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP #define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d62ed4b15dd3a6ea078ef7ed0b70e818a30af611 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp @@ -0,0 +1,348 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/dpp_gemm.hpp" + +namespace ck { + +/** + * Blockwise GEMM that uses DPP instruction modifier to limit the amount of data loaded for each + * thread by sharing the data between threads in a lanegroup. + * + * In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`, there are + * `MRepeat` iterations for `M` dimension and `NRepeat` for `N` one. + * In total, the algorithm runs using + * `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves. + */ +template +struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto dpp_gemm = DppGemm{}; + + static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex_M0_M1_M2_K() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex_K_M(); + const auto dpp_a_idx_k = dpp_a_idx[I0]; + const auto dpp_a_idx_m = dpp_a_idx[I1]; + return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k); + } + + __device__ static auto CalculateBThreadOriginDataIndex_N0_N1_N2_K() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex_K_N(); + const auto dpp_b_idx_k = dpp_b_idx[I0]; + const auto dpp_b_idx_n = dpp_b_idx[I1]; + return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk(); + const auto blk_m_offset = blk_idx[I0]; + const auto blk_n_offset = blk_idx[I1]; + + constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_NPerDpp_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerDpp))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_m_offset))[I0]; + const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_n_offset))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "Wrong! Block descriptors should be known at the time of compilation."); + +#if defined(__HIP_DEVICE_COMPILE__) + // Host wave size can be different than the device one and this assert could fail for host, + // but it does matter only for device. + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); +#endif + + static_assert(MPerBlock % (MPerDpp * MRepeat) == 0, + "Invalid parameters. MPerBlock must be divisible by MPerDpp * MRepeat."); + static_assert(NPerBlock % (NPerDpp * NRepeat) == 0, + "Invalid parameters. NPerBlock must be divisible by NPerDpp * NRepeat."); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths(); + constexpr auto M = c_m_n_tblk_lens[I0]; + constexpr auto N = c_m_n_tblk_lens[I1]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths(); + constexpr auto M = c_m_n_tblk_lens[I0]; + constexpr auto N = c_m_n_tblk_lens[I1]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return c_block_desc_m0_n0_m1_n1_m2_n2; + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + return c_block_desc_g_m0_n0_m1_n1_m2_n2; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return c_grid_desc_m0_n0_m1_n1_m2_n2; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return c_grid_desc_g_m0_n0_m1_n1_m2_n2; + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + static_for<0, KPerThread, KPack>{}([&](auto k) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using dpp_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + dpp_gemm.template Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegDpp] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, dpp_gemm.GetRegSizePerDpp())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex_M0_M1_M2_K()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex_N0_N1_N2_K()}; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index d75f37d7b3915408b05c1bc77b478ab3e41dc8b0..b3d45f3d0c843c5e72a37b4a9a7617f2f1abc16a 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -221,49 +221,102 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, I0, I0, I0), - a_thread_buf); - - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0), - b_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; + // basic intrinsic to determine loopover direction + if constexpr(MRepeat < NRepeat) + { + static_for<0, KPerBlock / WmmaK, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0), + b_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + }); }); - - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); }); - }); - }); + } + else + { + static_for<0, KPerBlock / WmmaK, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0), + b_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0), + a_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } } protected: diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 5328dfde9bc09039af03892467d94545002c56fd..904a96cc9f8d226db37ca469bb71081ae20d42b4 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -1,30 +1,16 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" namespace ck { -enum struct LoopScheduler -{ - Default, - Interwave, -}; - -constexpr LoopScheduler make_default_loop_scheduler() -{ -#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING - return LoopScheduler::Interwave; -#else - return LoopScheduler::Default; -#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING -} - template __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) @@ -42,7 +28,8 @@ MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) } template {}; + static constexpr auto xdlops_gemm = XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; @@ -308,9 +295,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -332,25 +319,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 b_thread_buf); static_for<0, KPerThread, KPack>{}([&](auto k) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf + a_thread_vec.template AsType()(i) = a_thread_buf [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf + b_thread_vec.template AsType()(i) = b_thread_buf [Number{}]; }); - using mfma_input_type = - typename vector_type::type; + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); }); }); @@ -370,8 +359,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -380,8 +369,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1, A_K1>; - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -399,7 +388,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the // default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0 template struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { @@ -493,20 +485,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = + a_thread_vec.template AsType()(i) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = + b_thread_vec.template AsType()(i) = b_thread_buf[Number{}]; }); - using mfma_input_type = - typename vector_type::type; + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -528,8 +522,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 // TODO: insert setprio in more precise manner since we // could have more than >1 MFMA instructions in single call xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), c_thread_buf.GetVectorTypeReference(Number{})); if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) { @@ -555,8 +549,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, I1, I1, Number{})); - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -565,8 +559,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1, A_K1>; - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -582,7 +576,8 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 }; template {}.K0PerXdlops, - index_t BMmaKStride = - KPack* XdlopsGemm{}.K0PerXdlops> +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename ATileDesc, + typename BTileDesc, + typename AMmaTileDesc, + typename BMmaTileDesc, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t MPerXDL, + index_t NPerXDL, + index_t MRepeat, + index_t NRepeat, + index_t KPack, + bool TransposeC = false, + index_t AMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops, + index_t BMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops> struct BlockwiseGemmXdlops_v2 { static constexpr auto I0 = Number<0>{}; @@ -668,7 +666,8 @@ struct BlockwiseGemmXdlops_v2 static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = + XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp index aa814ab00939d0c36866af01b6cf2057dcbd5121..8ae1ba3f34c10d8359b470ed1172f1b70a7fa8b5 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp b/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp index 7e62a822a8f042d238905295b7f88c1d9bdd88f8..2fb7242708da58828b1d6f5c6b78ea472783814f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -35,8 +35,8 @@ struct BlockwiseSoftmax static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0); static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1); - using ThreadSliceDesc_M = decltype( - make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0)))); + using ThreadSliceDesc_M = decltype(make_naive_tensor_descriptor_packed( + make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0)))); using ThreadwiseMaxReduce = typename conditional< IgnoreNaN, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp index 03e4d42d3a1f9eb1b180c95368b905e619e67110..d8da134a3415a6976b27d0a4fdd7f13798d0245b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp b/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp index 316508651e4bd4dad2e47edf020a4145176fa753..820a08fc482b44a3509188db0f98aecabb692a82 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/tensor_description/cluster_descriptor.hpp" -#include "ck/utility/reduction_common.hpp" +#include "ck/utility/get_shift.hpp" namespace ck { @@ -35,10 +35,11 @@ struct BlockwiseWelford static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + template __device__ static inline void - Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b) + Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b) { - int count = count_a + count_b; + CountDataType count = count_a + count_b; T count_b_over_count = count == 0 ? type_convert(0) : type_convert(count_b) / count; T delta = mean_b - mean_a; mean_a += delta * count_b_over_count; @@ -46,11 +47,12 @@ struct BlockwiseWelford count_a = count; } - __device__ static void Run(T& mean_value, T& var_value, int& count) + template + __device__ static void Run(T& mean_value, T& var_value, CountDataType& count) { __shared__ T mean_block_buf[BlockSize]; __shared__ T var_block_buf[BlockSize]; - __shared__ int count_block_buf[BlockSize]; + __shared__ CountDataType count_block_buf[BlockSize]; constexpr auto cluster_len_shift = get_shift(); @@ -76,13 +78,13 @@ struct BlockwiseWelford index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + make_tuple(0, indOffset)); - T mean1 = mean_block_buf[offset1]; - T var1 = var_block_buf[offset1]; - int count1 = count_block_buf[offset1]; + T mean1 = mean_block_buf[offset1]; + T var1 = var_block_buf[offset1]; + CountDataType count1 = count_block_buf[offset1]; - T mean2 = mean_block_buf[offset2]; - T var2 = var_block_buf[offset2]; - int count2 = count_block_buf[offset2]; + T mean2 = mean_block_buf[offset2]; + T var2 = var_block_buf[offset2]; + CountDataType count2 = count_block_buf[offset2]; Merge(mean1, var1, count1, mean2, var2, count2); diff --git a/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp b/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp index 2163ad32383d4f940ab4cf3b041ab2d5c61456cc..82667e235238170806846940e5d153ea20992b73 100644 --- a/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp +++ b/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/tensor_description/cluster_descriptor.hpp" -#include "ck/utility/reduction_common.hpp" +#include "ck/utility/get_shift.hpp" #include "ck/utility/reduction_functions_accumulate.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp index 04ad75bd7de4a9657c2aaaf21b40205c3ff9c7d8..2c5fbc3937c5e3782156bf81f46b4fb4b1f1db46 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -94,6 +94,21 @@ struct ThreadGroupTensorSliceTransfer_v4r1 } } + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + } + } + template __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp index 5c47a49b38b6a30eebc9190cc338d4dbc0bc8524..905a59f56e3b42bfb4b2bdf8652a3e81b2872076 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..83cb9fb5de4e047a71bbcaf47f67df925b2dac74 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r1r2 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1r2( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.template Run( + src_desc, src_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + } + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_block_slice_origin) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r1r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp index aa33fc083f15cfd8904c8fa086a866dbc7817e7c..17110c8358dbaf7e602e7f9822968e08f80e14ef 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp index eb5f589a4ada976c9be4a5001fb6fc288c7e9a43..9a5317dd126a818bc88043408427f0f105cb7f6e 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp index 3bd7806389b59542b191e61f5369aebf7300fc6a..993d90e356def70ba54e357889277188b4f21971 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1a9bb3213a541c49358035b1b51bbf28aac26d87 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp @@ -0,0 +1,214 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does following things to avoid scratch memory issue +// 1. Pass tensor descritpors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + index_t SrcScalarPerVector, + index_t DstScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags> +struct ThreadGroupTensorSliceTransfer_v7r2 +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs); + } + } + + template + using is_tuple = decltype(std::declval().IsTuple()); + + template + __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + if constexpr(is_detected::value) + threadwise_transfer_.RunWrite(dst_descs, dst_bufs); + else + threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs)); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) + { + static_for<0, SrcDescs::Size(), 1>{}( + [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); }); + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step) + { + static_for<0, DstDescs::Size(), 1>{}( + [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); }); + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp b/include/ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dc08a2c88b2abbd9b26412121ae223af6c3e1c01 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace conv_tensor_rearrange_op { + +struct BaseConvTensorRearrangeOp +{ +}; + +struct ImageToColumn : public BaseConvTensorRearrangeOp +{ + static constexpr const char* name = "Image to Column"; +}; + +struct ColumnToImage : public BaseConvTensorRearrangeOp +{ + static constexpr const char* name = "Column to Image"; +}; + +template ::value, + bool>::type = false> +std::ostream& operator<<(std::ostream& os, const BaseConvTensorRearrangeOp&) +{ + os << Op::name; + return os; +} + +} // namespace conv_tensor_rearrange_op +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp index a4a29f5d5edc3b840fc7489d5f706231990cad70..cab5e213651dfc0756e03a82ea4389042b1539ae 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat switch(s) { case ConvolutionBackwardDataSpecialization::Default: return "Default"; - case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: - return "FFilter1x1Stride1Pad0"; + case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; default: return "Unrecognized specialization!"; } } diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp index 20b2a152b9d276b0bd0f6e1aa5798b4817b013a9..01bb806789c3d9e50f706d7c88982b4a6624ff57 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index 953ff1e06ed07b24968c7f5c0842161ac66643ed..adfa1689c66509c8c194985365356740d4c90473 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp b/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5a2b082c31aaffa401abc126b8770c3ee9f35686 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceAvgPoolBwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_dout, + void* p_din, + std::vector dout_n_k_wos_lengths, + std::vector dout_n_k_wos_strides, + std::vector din_n_k_wos_length, + std::vector din_n_k_wos_strides, + std::vector window_k_c_xs_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 5946daf21ec169612c3207282007e05671882f09..198169011107fb0f236d4657c399ff0534ce2c98 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp index 9fcd893c7a8c38d31f480b13eb8e7bfc97ec7fef..ee7af0117d17acc561e5943ac59611e188b6ba4c 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp index e755913280f73e325daf1352324bab8d3df8a3a9..6cc2c7bb2f6c176f2d84fdda4be2140db5564360 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp index af681127f30404f182475039add4443e46213398..91b4b6b91b6b6be5d08d88e5ca68c9b62fb02aea 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp index 116e62c00907e0f72031552f46b8a3ea975a18f5..f18dc3290600e63edb869d93dcb1a903206e32ab 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp index eacc5976d3ea3885a8294f1ee74a249f0a3d52a7..8234e29486b1af8ae0dd43e58d4be9b63b5ed34f 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp index c1f85e575ce4b2e0d70c05010bfca6e9ed9b84a8..09259224e75cc67e092eeeaab1af793b050cc1f0 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp index bde71806daa0543c41813f2afb0dd34f16bfb76e..be8105c96726dfb3d8931b376a668c3296040706 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp b/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp index d39f3b7cbcfc9ccaccca5f755be747cc4896dc1b..2c0da692570b455da42540bb72793b9e629cb823 100644 --- a/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp b/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp index aa93dd9c19d77fbee7ccdafd1e41e4c4f64735cd..e3962e177ee824981c35fbe2ba04664e4124315b 100644 --- a/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp b/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp index 8a00fd9db3327a7c536b1ada0ca090bfa847a256..69103b6f44297ed65ed4233e3d458ac6927bcd16 100644 --- a/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp index aedae53800b03cb78e6af5b15c8c6dfcd9792eea..8484212118c796f68d1bf066eab7b8ddbb231797 100644 --- a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "device_base.hpp" diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3116d404742e1769a1fc6c9b75a1f22f5eea689a --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A0[M0, M1, ... K0, K1, ...], ... +// input : B0[N0, N1, ... K0, K1, ...], ... +// input : D0[M0, M1, ... N0, N1, ...], D1[M0, M1, ... N0, N1, ...], ... +// output : E[M0, M1, ... N0, N1, ...] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceContractionMultipleABD : public BaseOperator +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + const std::array, NumATensor>& a_ms_ks_lengths, + const std::array, NumATensor>& a_ms_ks_strides, + const std::array, NumBTensor>& b_ns_ks_lengths, + const std::array, NumBTensor>& b_ns_ks_strides, + const std::array, NumDTensor>& d_ms_ns_lengths, + const std::array, NumDTensor>& d_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp index dbc525c099bca293d106ac6d3c37ac72b2d749cc..118ade8978643f10dcfafbefb552fe485d6697e4 100644 --- a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp b/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp index 82054a3c9423f819d17a003b8c22844f0a2273fe..eb1b85ec822aa96822b5a288d55a5d3a0898f1e9 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp index 4b9881088dde07a62e92c5f77a19edc9c4b7670f..4dc11dbefd73619c39febe31a91a953b25049f05 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp index 5a627deeb2221f5e271532d45bc8a544c8cceeba..7d3845666cf2289f34b186a5aaa8ee616812f30a 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp index cc139303c929ae999d182ad09e9a4ef33f5209e1..3a49ac632e7de745e78ad5670dc57f425ab9eb33 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b28204fd841cd1430e30a255dacf5b9a4583a929 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * \brief Convolution Tensor Rearrange. + * + * This Device operator supports converting an image to + * the GEMM representation (Image to Column) and + * converting a GEMM form to the image (Column to Image). + * Supported layouts: + * [G, N, Di, Hi, Wi, C] <-> [G, N * Do * Ho * Wo, Z * Y * X * C] + * [N, Di, Hi, Wi, G, C] <-> [N * Do * Ho * Wo, G, Z * Y * X * C] + * + * \tparam NDimSpatial Number of spatial dimensions. + * \tparam ImageLayout Input Layout. + * \tparam InputDataType Input Data Type. + * \tparam OutputDataType Output Data Type. + * \tparam ConvTensorRearrangeOp Operation type: ImageToColumn, ColumnToImage. + */ +template +struct DeviceConvTensorRearrange : public BaseOperator +{ + + /** + * \brief Make argument pointer for image to column. + * + * \param p_in A pointer to the device memory of the input image. + * \param p_out A pointer to the device memory of the output. + * \param G Convolution number of groups. + * \param N Convolution batch size. + * \param C Convolution number of channels. + * \param input_spatial_lengths Input spatial lengths. + * \param filter_spatial_lengths Filter spatial lengths. + * \param output_spatial_lengths Output spatial lengths. + * \param image_g_n_c_wis_strides Image strides in order [G, N, C, D, H, W]. + * \param gemm_g_m_k_strides Gemm form strides. + * \param conv_filter_strides Convolution filter strides. + * \param conv_filter_dilations Convolution filter dilations. + * \param input_left_pads Convolution left pads. + * \param input_right_pads Convolution right pads. + * \return Pointer to the argument. + */ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + void* p_out, + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_elementwise.hpp index f9f913a7c1f1a6cdfdb41f075f2609e8e2407d07..db0e4bd83f46ad001492876ff8a712895a873c22 100644 --- a/include/ck/tensor_operation/gpu/device/device_elementwise.hpp +++ b/include/ck/tensor_operation/gpu/device/device_elementwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp index 9491a92247c71e3938d0f7d81cc0d82741b693de..c56a947ec9e11a0a5dce2cd39ce86d8e6a3efb8c 100644 --- a/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp +++ b/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_gemm.hpp index c0af6f80faf606a22678e516be994a75c1d56eca..adf909821dbbdde12526513a7b6218e8ebbd219d 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp index 4c2161eaed5b9c7d7913686f44c360c044ca3b07..a7f42c3b35e237cb4ad9edcbcd137f453af59051 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cbb9fadc6d6f3f03673f4c5a7f13229f29b5d849 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A0[M, K], B0[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleABD : public BaseOperator +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + ck::index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp index 9113bb7b7454443cce4871bf263433d4f907bfba..a44356dc2405bfe6ed448a8193d55e86bd29415c 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp index a67a09b87416bf11cc285ab1b1d9dc684b17103f..0258858fe50e2d433e5ddfe007e8d06227b47c0c 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp index f4881e32f620be9c31649085813a257e9e84a598..539e83f7cb8a5dfe99b2ecb3da491f9661330265 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp index fcc088ca43d1b8c6224c0015e3eb7434038af4b2..eaa7671c6424b4ecf5ad87c4e790a9c478645b9b 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp index c701bff57f8eb7db155e9abdfd3ab7210e7eeffd..dfc27e726eaafaddf13ac62c530567eb32ab0ddc 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -20,7 +20,8 @@ template + typename CElementwiseOperation, + typename ComputeType = CDataType> struct DeviceGemmSplitK : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const void* p_a, @@ -48,7 +49,8 @@ template + typename CElementwiseOperation, + typename ComputeType = CDataType> using DeviceGemmSplitKPtr = std::unique_ptr>; + CElementwiseOperation, + ComputeType>>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ed081ad7fc806780dfebf92fbdcbb1e7f99dfb86 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmStreamK : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t NumSKBlocks = 0) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmStreamKPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp index 173c613a325d8c594298ff6751060fecf2f8d453..ba81948440acd892a9877a7615f95642628b6347 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp index 3350aec8d3174010cf0d926c3a91802ce3b8e7e7..2abf1d5a10c7e8645ca51d842390d50d5fbce5dd 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -29,7 +29,9 @@ template + typename CDEElementwiseOperation, + typename AComputeType = ADataType, + typename BComputeType = AComputeType> struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp index 1258aed71c502c659f065d08e8b1a84922c26c83..7296e4faaa30c56cab55ac5140b94a20c45782b1 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -20,24 +20,25 @@ template + typename OutElementwiseOperation, + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedConvBwdWeight : public BaseOperator { virtual std::unique_ptr MakeArgumentPointer(const void* p_in, void* p_wei, const void* p_out, - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp index 644c7ee9a9107718dc609836e6cd8abe7f2dad21..025c43e75cc204b5bb8b3bdc1083d30aaf760bbc 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp index 1cc30fd9e6ad10336a49a261b3ee7c47f9053435..2ca82dc6da230d36643d472be94d1557e297aa3c 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp @@ -29,7 +29,8 @@ template + typename CDEElementwiseOperation, + typename ComputeType = ADataType> struct DeviceGroupedConvFwdMultipleD : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp index 181ee4b428ba236d92cf71bbf33b24e53b2491c0..1e03405536a20a9f050a7943243f1c406594b836 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp @@ -1,4 +1,8 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + #pragma once + #include #include @@ -31,7 +35,7 @@ struct DeviceGroupedGemm : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); - static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor"); + static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor"); virtual std::unique_ptr MakeArgumentPointer(std::vector& p_a, diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fcb2ba6a4d7be839c11cf9794bb7beccf7845d3c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct GroupedGemmKernelArgument +{ + const void* p_a_grid; + const void* p_b_grid; + std::array p_ds_grid; + void* p_e_grid; + + index_t M; + index_t N; + index_t K; + + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideE; +}; + +template +struct DeviceGroupedGemmFixedNK : DeviceGroupedGemm +{ + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0; + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; + virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp index b066a4458518b378313225e24817253922962d29..fae65097407cdb26df206308442874c0fc326fe6 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..06d180d30fb0f3c8e3b348dd1668ff582de87493 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp @@ -0,0 +1,39 @@ +#pragma once +#include +#include + +#include "device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm +{ + virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp b/include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5a4a9cac1e8ea78b6683b1c4ce9629a8c87b9802 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// For pooling which used indexable operation, such as MaxPool, MinPool...etc +template +struct DeviceMaxPoolBwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_dout, + const void* p_indices, + void* p_din, + index_t dout_length, + index_t din_length, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp index ee4b53e2fcc43f44aba29c3409963a1b94834cbc..f68022ca04de757fee1240ced4b504601b183d91 100644 --- a/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_normalization.hpp index 03601ce8312ccea06d60236779aa78a9ea306e01..97e83ebab2efebe08871a757fc037512746b97df 100644 --- a/include/ck/tensor_operation/gpu/device/device_normalization.hpp +++ b/include/ck/tensor_operation/gpu/device/device_normalization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,8 +14,8 @@ namespace device { template @@ -27,6 +27,8 @@ struct DeviceNormalization : public BaseOperator const std::vector gammaStrides, const std::vector betaStrides, const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, const std::vector reduceDims, double epsilon, const void* p_x, @@ -43,16 +45,16 @@ struct DeviceNormalization : public BaseOperator template using DeviceNormalizationPtr = std::unique_ptr>; diff --git a/include/ck/tensor_operation/gpu/device/device_permute.hpp b/include/ck/tensor_operation/gpu/device/device_permute.hpp index 9daa2be37338c8309bd2131f263d324fe15b7ab2..c994cf02c6aab5b85fcd2fcc29f2716de28b861d 100644 --- a/include/ck/tensor_operation/gpu/device/device_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp deleted file mode 100644 index 3b376c6f73f066729fa118013970682ca31b340e..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp +++ /dev/null @@ -1,40 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/tensor_operation/gpu/device/device_base.hpp" -#include "ck/utility/reduction_enums.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template -struct DevicePool2dFwd : public BaseOperator -{ - virtual std::unique_ptr - MakeArgumentPointer(const void* in_dev, - void* out_dev, - void* out_indices_dev, - ck::index_t N, - ck::index_t C, - std::array input_spatial_lengths, - std::array window_spatial_lengths, - std::array output_spatial_lengths, - std::array window_strides, - std::array input_left_pads, - std::array input_right_pads) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; -}; - -template -using DevicePool2dFwdPtr = std::unique_ptr>; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp b/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..62071c43a1a0275d2868199af95d33354a365db5 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/reduction_enums.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePoolFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in_dev, + void* p_out_dev, + void* p_out_indices_dev, + std::vector input_n_c_wis_lengths, + std::vector window_xs_lengths, + std::vector output_n_c_wos_lengths, + std::vector input_n_c_wis_stride, + std::vector output_n_c_wis_stride, + std::vector indices_n_c_wis_stride, + std::vector window_xs_strides, + std::vector window_xs_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector pooling_dims) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_put_element.hpp b/include/ck/tensor_operation/gpu/device/device_put_element.hpp new file mode 100644 index 0000000000000000000000000000000000000000..17df2de37b68c94cbd9ed9091cdfcc106ad8e287 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_put_element.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/reduction_enums.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// output[indices] = input +template +struct DevicePutElement : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_input, + const void* p_indices, + void* p_output, + index_t input_length, + index_t output_length, + ElementwiseOperation elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_reduce.hpp index c9209f2d7d681b44c8a27dd06626eb8a4b998a39..c2721b18455c6a46e5973fad32767ae43247f59b 100644 --- a/include/ck/tensor_operation/gpu/device/device_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_softmax.hpp b/include/ck/tensor_operation/gpu/device/device_softmax.hpp index 94f788e5177cd78d0ea917b0869c576ba8fe7bfb..1902fd09eec1ed01954b6cad47b0e97ff898eb40 100644 --- a/include/ck/tensor_operation/gpu/device/device_softmax.hpp +++ b/include/ck/tensor_operation/gpu/device/device_softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -18,7 +18,8 @@ template + index_t Rank, + index_t NumReduceDim> struct DeviceSoftmax : public BaseOperator { // @@ -49,8 +50,6 @@ struct DeviceSoftmax : public BaseOperator AccElementwiseOp acc_elementwise_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; - virtual index_t GetRank() const = 0; - virtual index_t GetNumReduceDim() const = 0; }; template -using DeviceSoftmaxPtr = std::unique_ptr< - DeviceSoftmax>; + index_t Rank, + index_t NumReduceDim> +using DeviceSoftmaxPtr = std::unique_ptr>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp index f59e6093e2ae024ccbf8827082c755eba8cee88f..eeccd977ccbf4440a349845826d45d5a3274ad37 100644 --- a/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp index fc913e9ba03611dfbabea020bd4054c0b953726d..0bb45b18c3e19b2ec5f9347c1e811d8734ee45a9 100644 --- a/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp b/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3a2280a75c27d897972c6f3149b8a0f4e63da7a9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp @@ -0,0 +1,575 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// In and Din = [N, C, Di, Hi, Wi] +// Out and Dout = [N, C, Do, Ho, Wo] +// Out = AvgPoolFwd(In) +// Din = AvgPoolBwd(Dout) +// Pooling dimension = D, H, W +template +struct DeviceAvgPool3dBwd_NDHWC_NDHWC : public DeviceAvgPoolBwd<3, + DOutDataType, + DInDataType, + tensor_layout::convolution::NDHWC, + tensor_layout::convolution::NDHWC> +{ + static constexpr ck::index_t NDimSpatial = 3; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto + Make3DGridDescriptor_Out_M_K_In_M(const std::vector& dout_n_c_wos_lengths, + const std::vector& din_n_c_wos_length, + const std::vector& dout_n_c_wos_strides, + const std::vector& din_n_c_wos_strides, + const std::vector& window_lengths, + const std::vector& window_strides, + const std::vector& window_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads, + const std::vector& tildes) + { + index_t i_ztilde = tildes[0]; + index_t i_ytilde = tildes[1]; + index_t i_xtilde = tildes[2]; + + const index_t N = dout_n_c_wos_lengths[0]; + const index_t C = dout_n_c_wos_lengths[1]; + + const index_t Di = din_n_c_wos_length[2]; + const index_t Hi = din_n_c_wos_length[3]; + const index_t Wi = din_n_c_wos_length[4]; + + const index_t Do = dout_n_c_wos_lengths[2]; + const index_t Ho = dout_n_c_wos_lengths[3]; + const index_t Wo = dout_n_c_wos_lengths[4]; + + const index_t Z = window_lengths[0]; + const index_t Y = window_lengths[1]; + const index_t X = window_lengths[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t ConvStrideD = window_strides[0]; + const index_t ConvStrideH = window_strides[1]; + const index_t ConvStrideW = window_strides[2]; + + const index_t ConvDilationD = window_dilations[0]; + const index_t ConvDilationH = window_dilations[1]; + const index_t ConvDilationW = window_dilations[2]; + + const auto out_n_do_ho_wo_c_grid_desc = + make_naive_tensor_descriptor(make_tuple(N, Do, Ho, Wo, C), + make_tuple(dout_n_c_wos_strides[0], + dout_n_c_wos_strides[2], + dout_n_c_wos_strides[3], + dout_n_c_wos_strides[4], + dout_n_c_wos_strides[1])); + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto ZDot = math::integer_divide_ceil(Z, ZTilde); + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto DTilde = Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); + const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on Tildes that contribute to non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IDTildeSliceEnd = + math::min(DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); + const auto IHTildeSliceEnd = + math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = + math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // ReduceK is different for each Reduce + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // Problem size of reduction kernel + const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * C; + const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; + + const index_t KRaw = ZDotSlice * YDotSlice * XDotSlice; + const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; + + // Out[ReduceM, ReduceK] + const auto out_n_dop_hop_wop_c_grid_desc = transform_tensor_descriptor( + out_n_do_ho_wo_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Do, I0, I0), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc = + transform_tensor_descriptor( + out_n_dop_hop_wop_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(ZDot, DTilde), + make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{})); + + const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor( + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C)), + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor( + out_grid_desc_reducemraw_reducekraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // In[ReduceM] + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C), + make_tuple(din_n_c_wos_strides[0], + din_n_c_wos_strides[2], + din_n_c_wos_strides[3], + din_n_c_wos_strides[4], + din_n_c_wos_strides[1])); + + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(XTilde, DTilde), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ztilde), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto in_grid_desc_reducemraw = transform_tensor_descriptor( + in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto in_grid_desc_reducem = + transform_tensor_descriptor(in_grid_desc_reducemraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem); + } + + using DoutDinGridDesc = decltype(Make3DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0})); + + using DoutGridDesc_M_K = remove_cvref_t>; + using DinGridDesc_M = remove_cvref_t>; + + // FIXME + // for NDHWC, the dim C is the fastest dimension, and is not reduced. + // Hence, it is in M dimension for reduction kernel. + static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K + + using PassThrough = tensor_operation::element_wise::PassThrough; + using Div = tensor_operation::element_wise::UnaryDivide; + + using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise; + + struct Argument : public BaseArgument + { + Argument(const DOutDataType* p_dout, + DInDataType* p_din, + std::vector dout_n_c_wos_lengths, + std::vector din_n_c_wos_length, + std::vector dout_n_c_wos_strides, + std::vector din_n_c_wos_strides, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + : p_dout_grid_{p_dout}, + p_din_grid_{p_din}, + dout_n_c_wos_lengths_{dout_n_c_wos_lengths}, + din_n_c_wos_length_{din_n_c_wos_length}, + dout_n_c_wos_strides_{dout_n_c_wos_strides}, + din_n_c_wos_strides_{din_n_c_wos_strides}, + num_reduce_{1}, + div_element_op_{window_lengths[0] * window_lengths[1] * window_lengths[2]} + { + std::vector Tildes(NDimSpatial); + for(int i = 0; i < NDimSpatial; ++i) + { + int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]); + Tildes[i] = window_strides[i] / GcdStrideDilation; + num_reduce_ *= Tildes[i]; + } + + for(index_t i_ztilde = 0; i_ztilde < Tildes[0]; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < Tildes[1]; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < Tildes[2]; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + math::integer_divide_ceil(window_lengths[0] - i_ztilde, Tildes[0]); + const auto YDotSlice = + math::integer_divide_ceil(window_lengths[1] - i_ytilde, Tildes[1]); + const auto XDotSlice = + math::integer_divide_ceil(window_lengths[2] - i_xtilde, Tildes[2]); + + if(ZDotSlice * YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto dout_din_grid_desc = + Make3DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths, + din_n_c_wos_length, + dout_n_c_wos_strides, + din_n_c_wos_strides, + window_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {i_ztilde, i_ytilde, i_xtilde}); + + dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]); + din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]); + } + } + } + } + + const DOutDataType* p_dout_grid_; + DInDataType* p_din_grid_; + std::vector dout_n_c_wos_lengths_; + std::vector din_n_c_wos_length_; + std::vector dout_n_c_wos_strides_; + std::vector din_n_c_wos_strides_; + + int num_reduce_; + std::vector dout_grid_desc_m_k_container_; + std::vector din_grid_desc_m_container_; + + Div div_element_op_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + for(index_t i = 0; i < arg.num_reduce_; i++) + { + const auto kernel = kernel_reduce_threadwise; + + ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0); + const index_t grid_size = (M / M_BlockTileSize); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.dout_grid_desc_m_k_container_[i], + arg.din_grid_desc_m_container_[i], + PassThrough{}, + arg.div_element_op_, + float(1), + arg.p_dout_grid_, + nullptr, + float(0), + arg.p_din_grid_, + nullptr); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + constexpr index_t Rank = NDimSpatial + 2; + int doutFastestDim = -1; + int dinFastestDim = -1; + + for(int i = 0; i < Rank; ++i) + { + if(arg.dout_n_c_wos_strides_[i] == 1) + doutFastestDim = i; + if(arg.din_n_c_wos_strides_[i] == 1) + dinFastestDim = i; + } + + if(doutFastestDim == -1 || dinFastestDim == -1) + { + if constexpr(InSrcOutDstVectorSize != 1) + return false; + } + else + { + if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0) + return false; + if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0) + return false; + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_dout, + void* p_din, + std::vector dout_n_c_wos_lengths, + std::vector din_n_c_wos_length, + std::vector dout_n_c_wos_strides, + std::vector din_n_c_wos_strides, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) override + { + constexpr index_t Rank = NDimSpatial + 2; + + if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank || + dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank) + throw std::runtime_error("dimension is incorrect"); + + if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial || + window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial || + input_right_pads.size() != NDimSpatial) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(static_cast(p_dout), + static_cast(p_din), + dout_n_c_wos_lengths, + din_n_c_wos_length, + dout_n_c_wos_strides, + din_n_c_wos_strides, + window_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceAvgPool3dBwd<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index 493822aeb27215406a8f7fad56021aca8e4246a7..4d599e8017991b4f1b1905de9acd23a219ac67f9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index 9bf8f5ccd9ccecf14e1e0fdd005c331124f32547..32c45bc57e0fa6735a24f83e7ecfb62755933690 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -57,7 +57,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = @@ -543,9 +543,13 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle EGridDesc_G_M_N e_grid_desc_g_m_n_; }; + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ADataType, // TODO: distinguish A/B datatype + ADataType, + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, @@ -588,14 +592,18 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle LoopSched>; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; // block-to-e-tile map using Block2ETileMap = @@ -840,8 +848,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index 20184458699811ce484863f4327539b80c64872e..ba22cf0bf868bf353cc4c550bc9c17ad5c2b1d03 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -75,7 +75,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -331,8 +331,13 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute, // DsDataType, @@ -378,13 +383,16 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{})); using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; // Argument @@ -571,6 +579,11 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute{}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index 20e9920d935c12f54dd7ea89af3279b829892fe7..3dbe8c67226ec4e4f0dc5bf4ac02ac49ae73cf46 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -61,7 +61,7 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -589,8 +589,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; // block-to-e-tile map using Block2ETileMap = @@ -580,8 +588,7 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + */ + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dl_multiple_d( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, + const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \ + defined(__gfx1101__) || defined(__gfx1102__)) + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = DsGridDesc_M0_M10_M11_N0_N10_N11::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + ds_grid_desc_m0_m10_m11_n0_n10_n11, + e_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, + integral_constant{}, + integral_constant{}); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_k0_m0_m1_k1; + ignore = b_grid_desc_k0_n0_n1_k1; + ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = e_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; + +#endif +} + +template && + is_same_v, + bool> = false> +struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD + +{ + using DeviceOp = DeviceBatchedGemmMultipleD_Dl; + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + template + static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + std::array BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = g_idx * static_cast(BatchStrideDs_[i]); + }); + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + std::array BatchStrideDs_; + index_t BatchStrideE_; + }; + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDlMultipleD_km_kn_mn; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using DsGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})); + using EGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + K_(K), + Batch_(Batch), + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + e_grid_desc_m0_m10_m11_n0_n10_n11_{}, + compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE}, + block_2_ctile_map_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + a_grid_desc_k0_m_k1_ = + DeviceBatchedGemmMultipleD_Dl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = + DeviceBatchedGemmMultipleD_Dl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideDs[i]); + }); + e_grid_desc_m_n_ = + DeviceBatchedGemmMultipleD_Dl::MakeEGridDescriptor_M_N(M, N, StrideE); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, e_grid_desc_m_n_)) + { + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_); + + ds_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n_); + + e_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + index_t K_; + + // Batch + index_t Batch_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_; + EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_; + + // for calculating batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + DefaultBlock2CTileMap block_2_ctile_map_; + + // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being. + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceBatchedGemmMultipleD_Dl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{" + << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{" + << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", " + << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting"); + } + + const index_t grid_size = + GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_.GetLength(I0), + arg.e_grid_desc_m_n_.GetLength(I1)) * + arg.Batch_; + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop.value; + + const auto kernel = + kernel_gemm_dl_multiple_d; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.ds_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.e_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + }; + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // TODO: Enable for gfx90a after complier fix + if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx1030" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" || + ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102") + { + bool pass = true; + pass = pass && arg.K_ % K1 == 0; + + pass = pass && GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.e_grid_desc_m_n_); + + return pass; + } + else + { + return false; + } + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + Batch, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + Batch, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmMultipleD_Dl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << M1PerThread << ", " + << N1PerThread << ", " + << KPerThread + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index 196dc86da15411e95a17ea05eb484f716561c383..9455ec48b33419d49837a425eca08188216bb729 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -69,7 +69,7 @@ __global__ void const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -196,6 +196,8 @@ template ; - using A0GridDesc_AK0_M_AK1 = remove_cvref_t; - using B0GridDesc_BK0_N_BK1 = remove_cvref_t; - using B1GridDesc_BK0_N_BK1 = remove_cvref_t; + using A0GridDesc_AK0_M_AK1 = + remove_cvref_t; + using B0GridDesc_BK0_N_BK1 = + remove_cvref_t; + using B1GridDesc_BK0_N_BK1 = + remove_cvref_t; // Argument struct Argument : public BaseArgument @@ -805,8 +812,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index ef9b90ba7cde65bf17336325be3c76e33c758e68..c4567108c9b2904f9a2b7e20edd32f4e4bd0d97b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -60,7 +60,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -801,6 +801,11 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceO static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 0c6c0ef7ad2ded496c0f03618a64dac8a2783b81..c38b7e1c5bc322a1765fdff02c0f26db205bb5e2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -68,7 +68,7 @@ __global__ void const C0MatrixMask c0_matrix_mask) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -197,7 +197,8 @@ template + int D0sTransferSrcScalarPerVector = 4, + LoopScheduler LoopSched = LoopScheduler::Default> struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle : public DeviceBatchedGemmSoftmaxGemmPermute; + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + D0sTransferSrcScalarPerVector>; // Argument // FIXME: constness @@ -530,6 +532,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle using D0DataType = remove_cvref_t>; // D0 pointer p_d0s_grid_(i) = static_cast(p_acc0_biases[i]); + // for check + d0s_nl_ns_lengths_strides_[i].push_back( + acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]); + d0s_nl_ns_lengths_strides_[i].push_back( + acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]); }); if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, @@ -608,6 +615,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle std::vector b_nz_kz_strides_; std::vector b1_nz_kz_strides_; std::vector c_mz_gemm1nz_strides_; + std::array, NumD0Tensor> d0s_nl_ns_lengths_strides_; index_t batch_count_; ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; @@ -715,8 +723,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle arg.Print(); #endif - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } @@ -772,6 +779,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle { return false; } + for(int i = 0; i < NumD0Tensor; i++) + { + if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 && + arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0) + { + return false; + } + if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1) + { + return false; + } + } return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index 84edde63eceb141198ed249879db48131ff14d5c..f12e05fd349e41b0b06bf48bea091751d6325d63 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -63,7 +63,7 @@ __global__ void const C0MatrixMask c0_matrix_mask) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -613,8 +613,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp index d35f194171e710b490a70906ed2cf2079ab7b77a..303eba156e328e16ce3dec4a24f88021432bf477 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -45,75 +45,46 @@ namespace device { * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). * */ -template +template __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdlops_v2r3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const index_t batch_count, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2CTileMap block_2_ctile_map) + kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + static_cast(karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + static_cast(karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + static_cast(karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_c_grid + c_batch_offset, + const auto a_grid_desc_k0_m_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); + const auto b_grid_desc_k0_n_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( + karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + c_batch_offset, p_shared, a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + c_grid_desc_m_n); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = batch_count; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; - ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = compute_ptr_offset_of_batch; - ignore = block_2_ctile_map; + ignore = karg; #endif } @@ -171,93 +142,6 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm{}; - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) - { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - - const auto a_grid_desc_k0_mp_k1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(M, PadM)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_k0_mp_k1; - } - - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) - { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - const auto b_grid_desc_k0_np_k1 = - transform_tensor_descriptor(b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_k0_np_k1; - } - - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - const auto c_grid_desc_m_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - }(); - - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - const auto c_grid_desc_mp_np = transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return c_grid_desc_mp_np; - } - - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - struct ComputePtrOffsetOfStridedBatch { ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, @@ -289,121 +173,82 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - NumGemmKPrefetchStage, - LoopSched, - PipelineVer>; - - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpecialization::MNKPadding, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using Problem = typename GridwiseGemm::Problem; // Argument - struct Argument : public BaseArgument + struct Argument : public Problem, public BaseArgument { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, - index_t Batch, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - Batch_(Batch), - a_grid_desc_k0_m_k1_{ - DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)}, - b_grid_desc_k0_n_k1_{ - DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, - c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideC}, - block_2_ctile_map_{ - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op}, - kraw_{K} + index_t Batch_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + Batch(Batch_), + compute_ptr_offset_of_batch{BatchStrideA, BatchStrideB, BatchStrideC} { - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - } } - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - index_t Batch_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; - Block2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - index_t kraw_; + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; }; // Invoker @@ -411,107 +256,39 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm 0) { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + karg.Print(); } -#endif - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!GridwiseGemm::CheckValidity(karg)) { throw std::runtime_error( - "wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_; - - const auto K = - arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); + gdx *= karg.Batch; float ave_time = 0; - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) { - const auto kernel = kernel_batched_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - ComputePtrOffsetOfStridedBatch, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.Batch_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.compute_ptr_offset_of_batch_, - arg.block_2_ctile_map_); + const auto kernel = + kernel_batched_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); } else { - const auto kernel = kernel_batched_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - ComputePtrOffsetOfStridedBatch, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.Batch_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.compute_ptr_offset_of_batch_, - arg.block_2_ctile_map_); + const auto kernel = + kernel_batched_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); } return ave_time; @@ -531,17 +308,14 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm(static_cast(p_a), static_cast(p_b), @@ -619,12 +385,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm" << " NumGemmKPrefetchStage: " << NumGemmKPrefetchStage << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp index ab16a757f699bb4e4ba781873bbf87538633d9c1..f46237e005642d433c29f5ab0f027a224043dbef 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp index 5a16ff765bbc5c0bc1b68d80262c88a0c8f84416..ad8e7956033db31db2992d6fa9f86cf16a29b150 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -10,12 +10,14 @@ #include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" #include "ck/tensor_operation/gpu/device/welford_helper.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp" #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp" -#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" namespace ck { namespace tensor_operation { @@ -114,8 +116,8 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwdinvariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64; + + // workspace for barrier objects, each barrier object consists of two integers + // TODO: allocate barrier object memory globally to reuse it by other operators + workspace_size += (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * + sizeof(int) * 2; } return (workspace_size); @@ -353,7 +362,6 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwdblkGroupSize_ > 1) { - // setup buffer used for intermediate welford mean pArg_->workspace_mean_ = static_cast(pArg_->p_workspace_); @@ -374,6 +382,18 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwdworkspace_count_ = reinterpret_cast(pArg_->workspace_variance_) + variance_space_sz; + + index_t count_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t); + + count_space_sz = math::integer_least_multiple(count_space_sz, 64); + + pArg_->control_ = reinterpret_cast(pArg_->workspace_count_) + count_space_sz; + + index_t control_space_sz = (pArg_->invariant_length_ + M_BlockTileSize - 1) / + M_BlockTileSize * sizeof(int) * 2; + + hip_check_error(hipMemset(pArg_->control_, 0, control_space_sz)); }; }; @@ -402,6 +422,32 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd; + using GridwiseMultiblockWelfordFirstHalf_ = GridwiseMultiblockWelfordFirstHalf; - index_t numMeanVarCountBlockTileIteration = - (arg.blkGroupSize_ + KThreadClusterSize - 1) / KThreadClusterSize; - - const auto kern_multiblock_welford_first_half = - kernel_multiblock_welford_first_half; - - const auto kern_welford_second_half_batchnorm_forward_final = - kernel_welford_second_half_batchnorm_forward_final< - GridwiseWelfordSecondHalfBatchNormForwardFinal_, - XDataType, - YDataType, - AccDataType, - ScaleDataType, - BiasDataType, - MeanVarDataType, - YElementwiseOp, - XYGridDesc_M_K, - MeanVarCountGridDesc_M_K, - ScaleBiasMeanVarGridDesc_M, - ScaleBiasMeanVarGridDesc_M>; - - avg_time += - launch_and_time_kernel(stream_config, - kern_multiblock_welford_first_half, - dim3(arg.gridSize_), - dim3(BlockSize), - 0, - arg.x_grid_desc_m_k_, - mean_var_count_grid_desc_m_g, - get_reduce_count_per_thread, - arg.numBlockTileIteration_, - arg.p_x_, - static_cast(arg.workspace_mean_), - static_cast(arg.workspace_variance_), - static_cast(arg.workspace_count_)); - - avg_time += - launch_and_time_kernel(stream_config, - kern_welford_second_half_batchnorm_forward_final, - dim3(arg.gridSize_), - dim3(BlockSize), - 0, - arg.x_grid_desc_m_k_, - arg.y_grid_desc_m_k_, - mean_var_count_grid_desc_m_k, - arg.scale_grid_desc_m_, - arg.bias_grid_desc_m_, - arg.mean_var_grid_desc_m_, - arg.blkGroupSize_, - arg.numBlockTileIteration_, - numMeanVarCountBlockTileIteration, - arg.epsilon_, - static_cast(arg.workspace_mean_), - static_cast(arg.workspace_variance_), - static_cast(arg.workspace_count_), - arg.p_x_, - arg.p_scale_, - arg.p_bias_, - arg.y_elementwise_op_, - arg.p_y_, - arg.updateMovingAverage_, - arg.averageFactor_, - arg.resultRunningMean_, - arg.resultRunningVariance_, - arg.saveMeanInvVariance_, - arg.resultSaveMean_, - arg.resultSaveInvVariance_); + // It is found that: + // 1) gfx1030 does not support the GLC enabled vector load/store, so using the + // two-kernel method for gfx1030 + // 2) Profiler on gfx908 could hang even though it works when running examples + // 3) Single-kernel method works on gfx1100, but the performance it not better + // than two-kernel method (due to more warps participating the barrier) + if(ck::get_device_name() == "gfx90a") + { + const auto kern_multiblock_batchnorm_fwd_ = + kernel_multiblock_batchnorm_forward; + + avg_time += launch_and_time_kernel( + stream_config, + kern_multiblock_batchnorm_fwd_, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + mean_var_count_grid_desc_m_g, // for writing to mean/variance/count + // workspace by multiple workgroups + mean_var_count_grid_desc_m_k, // for reading from mean/variance/count + // workspace by each workgroup + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.p_x_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_), + static_cast(arg.control_), + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, // true or false + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, // true or false + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + } + else + { + const auto kern_multiblock_welford_first_half = + kernel_multiblock_welford_first_half; + + const auto kern_welford_second_half_batchnorm_forward_final = + kernel_welford_second_half_batchnorm_forward_final< + GridwiseWelfordSecondHalfBatchNormForwardFinal_, + XDataType, + YDataType, + AccDataType, + ScaleDataType, + BiasDataType, + MeanVarDataType, + YElementwiseOp, + XYGridDesc_M_K, + MeanVarCountGridDesc_M_K, + ScaleBiasMeanVarGridDesc_M, + ScaleBiasMeanVarGridDesc_M>; + + avg_time += launch_and_time_kernel( + stream_config, + kern_multiblock_welford_first_half, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + mean_var_count_grid_desc_m_g, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.p_x_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_)); + + avg_time += launch_and_time_kernel( + stream_config, + kern_welford_second_half_batchnorm_forward_final, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + mean_var_count_grid_desc_m_k, + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + arg.blkGroupSize_, + arg.numBlockTileIteration_, + arg.epsilon_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_), + arg.p_x_, + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + }; } else { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b826793c27be946a99d0d96a99d745fd6c618822 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp @@ -0,0 +1,714 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/device/welford_helper.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, + "Invalid thread cluster size assignments!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeXY2dDescriptor(const std::array& xyLengths, + const std::array& xyStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleXYLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + const auto tupleXYStrides = + generate_tuple([&](auto I) { return xyStrides[I]; }, Number{}); + + const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides); + + const auto grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; }, + Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + + return transform_tensor_descriptor(raw_grid_desc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{}); + + const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = workSizePerBlock * blkGroupSize - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto grid_desc_m_g = make_naive_tensor_descriptor( + make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_g_padded = + transform_tensor_descriptor(grid_desc_m_g, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_pass_through_transform(blkGroupSize)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_g_padded); + }; + + static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto reduceLength = blkGroupSize; + const auto grid_desc_m_k = make_naive_tensor_descriptor( + make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = + math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto + MakeScaleBiasMeanVar1dDescriptor(const std::array& lengths, + const std::array& strides) + { + const auto tupleLengths = + generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + const auto tupleStrides = + generate_tuple([&](auto I) { return strides[I]; }, Number{}); + + auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); + + auto grid_desc_m = transform_tensor_descriptor( + raw_grid_desc, + make_tuple(make_merge_transform(tupleLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = + transform_tensor_descriptor(grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, mPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (grid_desc_m_padded); + }; + + using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1)); + using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1})); + + struct Argument : public BaseArgument + { + Argument(const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const XDataType* p_x, + const ScaleDataType* p_scale, + const BiasDataType* p_bias, + const YElementwiseOp y_elementwise_op, + double epsilon, + YDataType* p_y, + MeanVarDataType* resultSaveMean, + MeanVarDataType* resultSaveInvVariance, + double averageFactor, + MeanVarDataType* resultRunningMean, + MeanVarDataType* resultRunningVariance) + : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths), + bnScaleStrides_(bnScaleStrides), + bnBiasStrides_(bnBiasStrides), + bnMeanVarStrides_(bnMeanVarStrides), + p_x_(p_x), + p_scale_(p_scale), + p_bias_(p_bias), + y_elementwise_op_(y_elementwise_op), + p_y_(p_y), + resultSaveMean_(resultSaveMean), + resultSaveInvVariance_(resultSaveInvVariance), + resultRunningMean_(resultRunningMean), + resultRunningVariance_(resultRunningVariance) + { + xyLengths_ = + shuffle_tensor_dimensions(xyLengths, reduceDims); + xStrides_ = + shuffle_tensor_dimensions(xStrides, reduceDims); + yStrides_ = + shuffle_tensor_dimensions(yStrides, reduceDims); + + std::tie(invariant_length_, reduce_length_) = + get_2d_lengths(xyLengths_); + + epsilon_ = type_convert(epsilon); + averageFactor_ = type_convert(averageFactor); + + updateMovingAverage_ = + (resultRunningMean != nullptr && resultRunningVariance != nullptr); + saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr); + + if(UseMultiblockInK) + { + int iterations = 1; + while(true) + { + int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + // we want the blkGroupSize be not more than 16 + if(testBlkGroupSize <= 16) + break; + + iterations++; + }; + + blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration_ = iterations; + } + else + { + blkGroupSize_ = 1; + numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_; + + x_grid_desc_m_k_ = + MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_); + y_grid_desc_m_k_ = + MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_); + scale_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_); + bias_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_); + mean_var_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_); + } + + AccDataType epsilon_; + AccDataType averageFactor_; + + bool updateMovingAverage_; + bool saveMeanInvVariance_; + + std::array xyLengths_; + std::array xStrides_; + std::array yStrides_; + + std::array bnScaleBiasMeanVarLengths_; + std::array bnScaleStrides_; + std::array bnBiasStrides_; + std::array bnMeanVarStrides_; + + const XDataType* p_x_; + const ScaleDataType* p_scale_; + const BiasDataType* p_bias_; + const YElementwiseOp y_elementwise_op_; + YDataType* p_y_; + + MeanVarDataType* resultSaveMean_; + MeanVarDataType* resultSaveInvVariance_; + + MeanVarDataType* resultRunningMean_; + MeanVarDataType* resultRunningVariance_; + + long_index_t invariant_length_; + long_index_t reduce_length_; + + int blkGroupSize_; + int numBlockTileIteration_; + size_t gridSize_; + + XYGridDesc_M_K x_grid_desc_m_k_; + XYGridDesc_M_K y_grid_desc_m_k_; + ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_; + ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_; + ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_; + + void* workspace_mean_; + void* workspace_variance_; + void* workspace_count_; + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + if(UseMultiblockInK && pArg_->blkGroupSize_ > 1) + { + // workspace for welford intermediate mean + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate variance + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate count + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64; + } + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + if(UseMultiblockInK && pArg_->blkGroupSize_ > 1) + { + + // setup buffer used for intermediate welford mean + pArg_->workspace_mean_ = static_cast(pArg_->p_workspace_); + + index_t mean_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType); + + mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); + + // setup buffer used for intermediate welford varirance + pArg_->workspace_variance_ = + reinterpret_cast(pArg_->workspace_mean_) + mean_space_sz; + + index_t variance_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType); + + variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); + + // setup buffer used for intermediate welfor count + pArg_->workspace_count_ = + reinterpret_cast(pArg_->workspace_variance_) + variance_space_sz; + }; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0; + + if(UseMultiblockInK && arg.blkGroupSize_ > 1) + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForMultiblockWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_); + + const auto mean_var_count_grid_desc_m_g = + DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor( + arg.invariant_length_, arg.blkGroupSize_); + + const auto mean_var_count_grid_desc_m_k = + DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor( + arg.invariant_length_, arg.blkGroupSize_); + + using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g); + using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k); + + using GridwiseMultiblockWelfordFirstHalf_ = + GridwiseMultiblockWelfordFirstHalf; + + using GridwiseWelfordSecondHalfBatchNormForwardFinal_ = + GridwiseWelfordSecondHalfBatchNormForwardFinal; + + const auto kern_multiblock_welford_first_half = + kernel_multiblock_welford_first_half; + + const auto kern_welford_second_half_batchnorm_forward_final = + kernel_welford_second_half_batchnorm_forward_final< + GridwiseWelfordSecondHalfBatchNormForwardFinal_, + XDataType, + YDataType, + AccDataType, + ScaleDataType, + BiasDataType, + MeanVarDataType, + YElementwiseOp, + XYGridDesc_M_K, + MeanVarCountGridDesc_M_K, + ScaleBiasMeanVarGridDesc_M, + ScaleBiasMeanVarGridDesc_M>; + + avg_time += + launch_and_time_kernel(stream_config, + kern_multiblock_welford_first_half, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + mean_var_count_grid_desc_m_g, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.p_x_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_)); + + avg_time += + launch_and_time_kernel(stream_config, + kern_welford_second_half_batchnorm_forward_final, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + mean_var_count_grid_desc_m_k, + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + arg.blkGroupSize_, + arg.numBlockTileIteration_, + arg.epsilon_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_), + arg.p_x_, + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + } + else + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForBlockwiseWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.numBlockTileIteration_, arg.reduce_length_); + + using GridwiseBatchNormForwardWithBlockwiseWelford_ = + GridwiseBatchNormForwardWithBlockwiseWelford; + + const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford< + GridwiseBatchNormForwardWithBlockwiseWelford_, + XDataType, + YDataType, + AccDataType, + ScaleDataType, + BiasDataType, + MeanVarDataType, + YElementwiseOp, + XYGridDesc_M_K, + ScaleBiasMeanVarGridDesc_M, + ScaleBiasMeanVarGridDesc_M, + GetReduceCountPerThreadFunctor>; + + avg_time += launch_and_time_kernel(stream_config, + kern_batchnorm_fwd, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.p_x_, + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, // true or false + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, // true or false + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + }; + + return (avg_time); + }; + + float Run(const BaseArgument* pArg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(pArg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* pArg) override + { + const Argument* pArg_ = dynamic_cast(pArg); + + if constexpr(XSrcYDstVectorDim == 0) + { + if(pArg_->xStrides_[NumInvariantDim - 1] != 1 || + pArg_->yStrides_[NumInvariantDim - 1] != 1) + return false; + + if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0) + return false; + } + else + { + if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1) + return false; + + if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0) + return false; + }; + + if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1) + return false; + if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0) + return false; + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0) + return false; + + if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0) + return false; + + bool is_valid = true; + + static_for<0, NumInvariantDim, 1>{}([&](auto I) { + if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I]) + is_valid = false; + }); + + if(!is_valid) + return false; + + return true; + }; + + std::unique_ptr MakeArgumentPointer( + const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* p_scale, + const void* p_bias, + double epsilon, + const YElementwiseOp y_elementwise_op, + void* p_y, + void* resultSaveMean, + void* resultSaveInvVariance, + double averageFactor, + void* resultRunningMean, + void* resultRunningVariance) override + { + return std::make_unique(xyLengths, + xStrides, + yStrides, + reduceDims, + bnScaleBiasMeanVarLengths, + bnScaleStrides, + bnBiasStrides, + bnMeanVarStrides, + static_cast(p_x), + static_cast(p_scale), + static_cast(p_bias), + y_elementwise_op, + epsilon, + static_cast(p_y), + static_cast(resultSaveMean), + static_cast(resultSaveInvVariance), + averageFactor, + static_cast(resultRunningMean), + static_cast(resultRunningVariance)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchNormFwdImpl<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp index 29978458bb202297868af555b3ffecb617e48fb7..a095521161c76d7bc831c76987a07691c415895c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -118,288 +118,21 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle return PadDescriptor_M_1d(desc_m, gridSize, blockSize); } - static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) - { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideA)); - } - }(); - - const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; - const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; - - const auto MPad = M - MRaw; - const auto KPad = K - KRaw; - - if constexpr(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both M and K - assert(K % AK1 == 0); - - const auto AK0 = K / AK1; - - const auto a_grid_desc_m_k = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_right_pad_transform(MRaw, MPad), - make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad M, but not K - assert(KRaw % AK1 == 0); - - const auto AK0 = KRaw / AK1; - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_right_pad_transform(MRaw, MPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad K, but not M - assert(K % AK1 == 0); - - const auto AK0 = K / AK1; - - const auto a_grid_desc_m_k = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(MRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else - { - // not pad M or K - assert(KRaw % AK1 == 0); - - const auto AK0 = KRaw / AK1; - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(MRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - } - - static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) - { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); - } - }(); - - const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; - const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; - - const auto NPad = N - NRaw; - const auto KPad = K - KRaw; - - if constexpr(GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both N and K - assert(K % BK1 == 0); - - const auto BK0 = K / BK1; - - const auto b_grid_desc_n_k = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_right_pad_transform(NRaw, NPad), - make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad N, but not K - assert(KRaw % BK1 == 0); - - const auto BK0 = KRaw / BK1; - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad K, but not N - assert(K % BK1 == 0); - - const auto BK0 = K / BK1; - - const auto b_grid_desc_n_k = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(NRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else - { - // not pad N or K - assert(KRaw % BK1 == 0); - - const auto BK0 = KRaw / BK1; - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(NRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - } - - static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) - { - const auto c_grid_desc_mraw_nraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), - make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), - make_tuple(I1, StrideC)); - } - }(); - - const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; - const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; - - const auto MPad = M - MRaw; - const auto NPad = N - NRaw; - - if constexpr(GemmSpec == GemmSpecialization::MNPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad M and N - return transform_tensor_descriptor(c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(MRaw, MPad), - make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad M, but not N - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad N, but not M - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - // not pad M or N - return c_grid_desc_mraw_nraw; - } - } - - using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); - using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); - // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< - ADataType, // TODO: distinguish A/B datatype + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, + GemmSpec, InMemoryDataOperationEnum::Set, - AGridDesc_AK0_M_AK1, - BGridDesc_BK0_N_BK1, - CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, @@ -433,108 +166,82 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched>; + using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + // Argument - struct Argument : public BaseArgument + struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm::Problem { - Argument(const ADataType* p_a_grid_real, - const ADataType* p_a_grid_imag, - const BDataType* p_b_grid_real, - const BDataType* p_b_grid_imag, - CDataType* p_c_grid_real, - CDataType* p_c_grid_imag, + using Problem = typename GridwiseGemm::Problem; + + Argument(const ADataType* p_a_grid_real_, + const ADataType* p_a_grid_imag_, + const BDataType* p_b_grid_real_, + const BDataType* p_b_grid_imag_, + CDataType* p_c_grid_real_, + CDataType* p_c_grid_imag_, CDataType* p_workspace, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_real_{p_a_grid_real}, - p_a_grid_imag_{p_a_grid_imag}, - p_b_grid_real_{p_b_grid_real}, - p_b_grid_imag_{p_b_grid_imag}, - p_c_grid_real_{p_c_grid_real}, - p_c_grid_imag_{p_c_grid_imag}, - p_aux_grid_{p_workspace}, - a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, - b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, - c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op} + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid_real{p_a_grid_real_}, + p_a_grid_imag{p_a_grid_imag_}, + p_b_grid_real{p_b_grid_real_}, + p_b_grid_imag{p_b_grid_imag_}, + p_c_grid_real{p_c_grid_real_}, + p_c_grid_imag{p_c_grid_imag_}, + p_aux_grid{p_workspace} { - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); - } - - const index_t grid_size = block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_); + const index_t grid_size = std::get<1>(GridwiseGemm::CalculateGridSize(M_, N_)); if constexpr(is_same::value) { - c_grid_desc_m_ = - DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize); + c_grid_desc_m = + DeviceOp::MakeDescriptor_M({M_, N_}, {StrideC_, I1}, grid_size, BlockSize); } else if constexpr(is_same::value) { - c_grid_desc_m_ = - DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize); + c_grid_desc_m = + DeviceOp::MakeDescriptor_M({M_, N_}, {I1, StrideC_}, grid_size, BlockSize); } - p_aux_2_grid_ = p_workspace + c_grid_desc_m_n_.GetElementSpaceSize(); + p_aux_2_grid = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_); } // private: - const ADataType* p_a_grid_real_; - const ADataType* p_a_grid_imag_; - const BDataType* p_b_grid_real_; - const BDataType* p_b_grid_imag_; - CDataType* p_c_grid_real_; - CDataType* p_c_grid_imag_; - CDataType* p_aux_grid_; - CDataType* p_aux_2_grid_; - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - CGridDesc_M_N c_grid_desc_m_n_; - CGridDesc_M c_grid_desc_m_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; + const ADataType* p_a_grid_real; + const ADataType* p_a_grid_imag; + const BDataType* p_b_grid_real; + const BDataType* p_b_grid_imag; + CDataType* p_c_grid_real; + CDataType* p_c_grid_imag; + CDataType* p_aux_grid; + CDataType* p_aux_2_grid; + CGridDesc_M c_grid_desc_m; }; // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) { throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); - const auto K = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1; float ave_time = 0; @@ -578,224 +285,154 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_xdl_cshuffle_v1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::DefaultBlock2CTileMap, - true>; - - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_real_, - arg.p_b_grid_real_, - arg.p_aux_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_imag_, - arg.p_b_grid_imag_, - arg.p_aux_2_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdl_cshuffle_v1; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_real, + arg.p_b_grid_real, + arg.p_aux_grid, + arg); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_imag, + arg.p_b_grid_imag, + arg.p_aux_2_grid, + arg); // c_real = aux - aux_2 ave_time += launch_and_time_kernel( stream_config, subtract_kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), - make_tuple(arg.c_grid_desc_m_), - make_tuple(const_cast(arg.p_aux_grid_), - const_cast(arg.p_aux_2_grid_)), - make_tuple(arg.p_c_grid_real_), + make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m), + make_tuple(arg.c_grid_desc_m), + make_tuple(const_cast(arg.p_aux_grid), + const_cast(arg.p_aux_2_grid)), + make_tuple(arg.p_c_grid_real), Subtract{}); - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_real_, - arg.p_b_grid_imag_, - arg.p_aux_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_imag_, - arg.p_b_grid_real_, - arg.p_aux_2_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_real, + arg.p_b_grid_imag, + arg.p_aux_grid, + arg); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_imag, + arg.p_b_grid_real, + arg.p_aux_2_grid, + arg); // c_imag = aux + aux_2 ave_time += launch_and_time_kernel( stream_config, add_kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), - make_tuple(arg.c_grid_desc_m_), - make_tuple(const_cast(arg.p_aux_grid_), - const_cast(arg.p_aux_2_grid_)), - make_tuple(arg.p_c_grid_imag_), + make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m), + make_tuple(arg.c_grid_desc_m), + make_tuple(const_cast(arg.p_aux_grid), + const_cast(arg.p_aux_2_grid)), + make_tuple(arg.p_c_grid_imag), Add{}); } else { - const auto kernel = kernel_gemm_xdl_cshuffle_v1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::DefaultBlock2CTileMap, - false>; - - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_real_, - arg.p_b_grid_real_, - arg.p_aux_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_imag_, - arg.p_b_grid_imag_, - arg.p_aux_2_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdl_cshuffle_v1; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_real, + arg.p_b_grid_real, + arg.p_aux_grid, + arg); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_imag, + arg.p_b_grid_imag, + arg.p_aux_2_grid, + arg); // c_real = aux - aux_2 ave_time += launch_and_time_kernel( stream_config, subtract_kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), - make_tuple(arg.c_grid_desc_m_), - make_tuple(const_cast(arg.p_aux_grid_), - const_cast(arg.p_aux_2_grid_)), - make_tuple(arg.p_c_grid_real_), + make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m), + make_tuple(arg.c_grid_desc_m), + make_tuple(const_cast(arg.p_aux_grid), + const_cast(arg.p_aux_2_grid)), + make_tuple(arg.p_c_grid_real), Subtract{}); - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_real_, - arg.p_b_grid_imag_, - arg.p_aux_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_imag_, - arg.p_b_grid_real_, - arg.p_aux_2_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_real, + arg.p_b_grid_imag, + arg.p_aux_grid, + arg); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_imag, + arg.p_b_grid_real, + arg.p_aux_2_grid, + arg); // c_imag = aux + aux_2 ave_time += launch_and_time_kernel( stream_config, add_kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_), - make_tuple(arg.c_grid_desc_m_), - make_tuple(const_cast(arg.p_aux_grid_), - const_cast(arg.p_aux_2_grid_)), - make_tuple(arg.p_c_grid_imag_), + make_tuple(arg.c_grid_desc_m, arg.c_grid_desc_m), + make_tuple(arg.c_grid_desc_m), + make_tuple(const_cast(arg.p_aux_grid), + const_cast(arg.p_aux_2_grid)), + make_tuple(arg.p_c_grid_imag), Add{}); } @@ -818,10 +455,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); } // polymorphic @@ -837,15 +476,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle CDataType* p_c_real, CDataType* p_c_imag, CDataType* p_workspace, - index_t MRaw, - index_t NRaw, - index_t KRaw, + index_t M, + index_t N, + index_t K, index_t StrideA, index_t StrideB, index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) { return Argument{p_a_real, p_a_imag, @@ -854,15 +493,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle p_c_real, p_c_imag, p_workspace, - MRaw, - NRaw, - KRaw, + M, + N, + K, StrideA, StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op}; + StrideC}; } static auto MakeInvoker() { return Invoker{}; } @@ -875,15 +511,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle void* p_c_real, void* p_c_imag, void* p_workspace, - index_t MRaw, - index_t NRaw, - index_t KRaw, + index_t M, + index_t N, + index_t K, index_t StrideA, index_t StrideB, index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, index_t /* KBatch */ = 1) override { return std::make_unique(static_cast(p_a_real), @@ -893,15 +529,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle static_cast(p_c_real), static_cast(p_c_imag), static_cast(p_workspace), - MRaw, - NRaw, - KRaw, + M, + N, + K, StrideA, StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op); + StrideC); } // polymorphic @@ -930,16 +563,22 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle return str.str(); } - std::size_t GetWorkspaceSize(index_t MRaw, - index_t NRaw, - [[maybe_unused]] index_t KRaw, + static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N( + M, GridwiseGemm::CalculateMPadded(M), N, GridwiseGemm::CalculateNPadded(N), StrideC); + + return c_grid_desc_m_n.GetElementSpaceSize(); + } + + std::size_t GetWorkspaceSize(index_t M, + index_t N, + [[maybe_unused]] index_t K, [[maybe_unused]] index_t StrideA, [[maybe_unused]] index_t StrideB, index_t StrideC) override { - const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC); - - return 2 * sizeof(CDataType) * c_grid_desc_m_n.GetElementSpaceSize(); + return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..567be5f36476da43f14958eab9f3da0438532654 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp @@ -0,0 +1,647 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" + +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Column to Image: +// input : gemm form [G, N * Do * Ho * Wo, Z * Y * X * C] +// output : input image [G, N, Di, Hi, Wi, C] +// input : gemm form [N * Do * Ho * Wo, G, Z * Y * X * C] +// output : input image [N, Di, Hi, Wi, G, C] +template = 1 && NDimSpatial <= 3, bool>::type = false> +struct DeviceColumnToImageImpl + : public DeviceConvTensorRearrange +{ + static constexpr bool is_NSpatialGC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + static constexpr bool is_GNSpatialC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = NDimSpatial == 1 ? I0 : Number{}; + static constexpr auto XIdx = Number{}; + + static constexpr auto spatial_offset = Number<3>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + static constexpr auto matrix_padder = + MatrixPadder{ + MPerBlock, 0 /* NPerBlock*/, KPerBlock}; + + // Calculate number of independent filters for given conv params + static index_t GetNumberOfIndependentFilters(const index_t input_spatial_len, + const index_t left_pad, + const index_t right_pad, + const index_t filter_len, + const index_t filter_stride, + const index_t filter_dilation, + const index_t image_offset) + { + const index_t x_eff = (filter_len - 1) * filter_dilation + 1; + const index_t next_filter_padded = + math::integer_divide_ceil(x_eff, filter_stride) * filter_stride; + // If filter_stride >= x_eff then each filter is independent + const index_t independent_filter_stride = + filter_stride >= x_eff ? filter_stride : next_filter_padded; + const index_t w_eff = input_spatial_len - image_offset + left_pad + right_pad - x_eff; + // There are no independent filters + if(w_eff < 0) + return 0; + const index_t independent_kernels_num = w_eff / independent_filter_stride + 1; + return independent_kernels_num; + } + + // Make column form descriptor + static auto + MakeInputDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& conv_filter_strides, + const std::array& gemm_g_m_k_strides, + const std::array& independent_filters, + const std::array& effs) + { + const index_t DoHoWo = ck::accumulate_n( + output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t CZYX = + C * ck::accumulate_n( + filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + + const index_t NStride = DoHoWo * gemm_g_m_k_strides[I1] * gemm_g_m_k_strides[I2]; + // Calculate the appropriate stride for each set of independent filters + // in each dimension + const index_t WStride = math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) * + gemm_g_m_k_strides[I1]; + const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) * + output_spatial_lengths[XIdx] * gemm_g_m_k_strides[I1]; + const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) * + output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] * + gemm_g_m_k_strides[I1]; + // Create descriptor for independent filters in each dimension and + // then merge them into column form + if constexpr(NDimSpatial == 1) + { + const auto desc_gemm_form = + make_naive_tensor_descriptor(make_tuple(N, independent_filters[XIdx], CZYX), + make_tuple(NStride, WStride, gemm_g_m_k_strides[I2])); + const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( + desc_gemm_form, + make_tuple(make_merge_transform(make_tuple(N, independent_filters[XIdx])), + make_pass_through_transform(CZYX)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters); + return desc_m_k; + } + else if constexpr(NDimSpatial == 2) + { + const auto desc_gemm_form = make_naive_tensor_descriptor( + make_tuple(N, independent_filters[YIdx], independent_filters[XIdx], CZYX), + make_tuple(NStride, HStride, WStride, gemm_g_m_k_strides[I2])); + const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( + desc_gemm_form, + make_tuple(make_merge_transform( + make_tuple(N, independent_filters[YIdx], independent_filters[XIdx])), + make_pass_through_transform(CZYX)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters); + return desc_m_k; + } + else if constexpr(NDimSpatial == 3) + { + const auto desc_gemm_form = make_naive_tensor_descriptor( + make_tuple(N, + independent_filters[ZIdx], + independent_filters[YIdx], + independent_filters[XIdx], + CZYX), + make_tuple(NStride, DStride, HStride, WStride, gemm_g_m_k_strides[I2])); + const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( + desc_gemm_form, + make_tuple(make_merge_transform(make_tuple(N, + independent_filters[ZIdx], + independent_filters[YIdx], + independent_filters[XIdx])), + make_pass_through_transform(CZYX)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters); + return desc_m_k; + } + } + + // Use MakeADescriptor_M_K from grouped convolution forward + static auto + MakeOutDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const std::array& image_offsets, + const std::array& independent_filters, + const std::array& effs) + { + std::array a_g_n_c_wis_lengths{1}; + std::array b_g_k_c_xs_lengths{1}; + std::array c_g_n_k_wos_lengths{1}; + + auto copy = [](const auto& x, auto& y, index_t dst_offset) { + std::copy(x.begin(), x.end(), y.begin() + dst_offset); + }; + + copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset); + copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset); + // Calculate descriptor only for independent filters + copy(independent_filters, c_g_n_k_wos_lengths, spatial_offset); + + // fill only significant values (C and N) + a_g_n_c_wis_lengths[I1] = N; + a_g_n_c_wis_lengths[I2] = C; + b_g_k_c_xs_lengths[I2] = C; + c_g_n_k_wos_lengths[I1] = N; + + // Modify pads to apply offsets + std::array input_left_pads_with_offset; + for(index_t i = 0; i < NDimSpatial; i++) + { + input_left_pads_with_offset[i] = math::max(0, input_left_pads[i] - image_offsets[i]); + } + // Modify input spatial lengths to apply offsets + for(index_t i = 0; i < NDimSpatial; i++) + { + a_g_n_c_wis_lengths[i + spatial_offset] -= + math::max(0, image_offsets[i] - input_left_pads[i]); + } + + // Strides to next independent filters + std::array independent_filter_strides; + for(index_t i = 0; i < NDimSpatial; i++) + { + index_t independent_filter_stride = + math::integer_divide_ceil(effs[i], conv_filter_strides[i]) * conv_filter_strides[i]; + // If conv stride is greater than whole filter size, use conv stride + independent_filter_strides[i] = conv_filter_strides[i] >= effs[i] + ? conv_filter_strides[i] + : independent_filter_stride; + } + + // Calculate image form descriptor for the modified convolution problem + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K( + a_g_n_c_wis_lengths, + image_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + {}, // not needed for A Descriptor + c_g_n_k_wos_lengths, + {}, // not needed for A Descriptor + // conv_filter_strides, + independent_filter_strides, + conv_filter_dilations, + input_left_pads_with_offset, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + return in_gemmm_gemmk_desc; + } + + using InputGridDesc = + remove_cvref_t; + using OutputGridDesc = remove_cvref_t; + + using Block2ETileMap = remove_cvref_t< + decltype(BlockToCTileMap_M00_N0_M01Adapt( + InputGridDesc{}))>; + + using GridwiseTensorRearrangeKernel = + GridwiseTensorRearrange>; + + struct Argument : public BaseArgument + { + Argument(const void* p_in, // input image + void* p_out, // output image + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + : G_(G), + C_(C), + X_(filter_spatial_lengths[NDimSpatial - I1]), + p_in_{static_cast(p_in)}, + p_out_{static_cast(p_out)}, + image_g_n_c_wis_strides_{image_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + compute_ptr_offset_of_batch_.BatchStrideA_ = gemm_g_m_k_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = image_g_n_c_wis_strides[I0]; + + const index_t x_eff = + (filter_spatial_lengths[XIdx] - 1) * conv_filter_dilations[XIdx] + 1; + const index_t y_eff = + NDimSpatial < 2 + ? I1 + : (filter_spatial_lengths[YIdx] - 1) * conv_filter_dilations[YIdx] + 1; + const index_t z_eff = + NDimSpatial < 3 + ? I1 + : (filter_spatial_lengths[ZIdx] - 1) * conv_filter_dilations[ZIdx] + 1; + + // Iterate over sets of independent filters + for(int z_img_offset = 0; z_img_offset < z_eff; + z_img_offset += conv_filter_strides[ZIdx]) + { + for(int y_img_offset = 0; y_img_offset < y_eff; + y_img_offset += conv_filter_strides[YIdx]) + { + for(int x_img_offset = 0; x_img_offset < x_eff; + x_img_offset += conv_filter_strides[XIdx]) + { + + std::array image_offsets; + std::array effs; + // Calculate the starting offset for a given set of + // independent filters + if constexpr(NDimSpatial == 1) + { + image_offsets = {x_img_offset}; + effs = {x_eff}; + } + if constexpr(NDimSpatial == 2) + { + image_offsets = {y_img_offset, x_img_offset}; + effs = {y_eff, x_eff}; + } + else if constexpr(NDimSpatial == 3) + { + image_offsets = {z_img_offset, y_img_offset, x_img_offset}; + effs = {z_eff, y_eff, x_eff}; + } + + std::array independent_filters; + for(index_t i = 0; i < NDimSpatial; i++) + { + independent_filters[i] = + GetNumberOfIndependentFilters(input_spatial_lengths[i], + input_left_pads[i], + input_right_pads[i], + filter_spatial_lengths[i], + conv_filter_strides[i], + conv_filter_dilations[i], + image_offsets[i]); + } + const index_t independent_filters_acum = ck::accumulate_n( + independent_filters.begin(), NDimSpatial, 1, std::multiplies<>()); + if(independent_filters_acum <= 0) + continue; + + const auto in_grid_desc_m_k = + MakeInputDescriptor_M_K(N, + C, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + gemm_g_m_k_strides, + independent_filters, + effs); + const auto out_grid_desc_m_k = + MakeOutDescriptor_M_K(N, + C, + input_spatial_lengths, + filter_spatial_lengths, + image_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + image_offsets, + independent_filters, + effs); + in_grid_desc_m_k_container_.push_back(in_grid_desc_m_k); + out_grid_desc_m_k_container_.push_back(out_grid_desc_m_k); + + const index_t x_idx = x_img_offset / conv_filter_strides[XIdx]; + const index_t y_idx = y_img_offset / conv_filter_strides[YIdx]; + const index_t z_idx = z_img_offset / conv_filter_strides[ZIdx]; + + const index_t x_offset_with_pad = + math::max(0, x_img_offset - input_left_pads[XIdx]); + const index_t y_offset_with_pad = + math::max(0, y_img_offset - input_left_pads[YIdx]); + const index_t z_offset_with_pad = + math::max(0, z_img_offset - input_left_pads[ZIdx]); + + // Memory offsets to next set of independent filters, + // move to independent filters in each dimension + const index_t in_offset = + (x_idx + y_idx * output_spatial_lengths[XIdx] + + z_idx * output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx]) * + gemm_g_m_k_strides[I1]; + // Move to independent filters in appropriate dimensions + const index_t out_offset = + x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] + + y_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + YIdx] + + z_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + ZIdx]; + + const InputDataType* p_in_with_offset = + static_cast(p_in) + in_offset; + OutputDataType* p_out_with_offset = + static_cast(p_out) + out_offset; + p_in_container_.push_back(p_in_with_offset); + p_out_container_.push_back(p_out_with_offset); + } + } + } + } + + void Print() const + { + for(std::size_t i = 0; i < in_grid_desc_m_k_container_.size(); i++) + { + std::cout << in_grid_desc_m_k_container_[i] << std::endl; + std::cout << out_grid_desc_m_k_container_[i] << std::endl; + } + } + + const ck::index_t G_; + const ck::index_t C_; + const ck::index_t X_; + + const InputDataType* p_in_; + OutputDataType* p_out_; + + const std::array& image_g_n_c_wis_strides_; + const std::array& conv_filter_strides_; + const std::array& conv_filter_dilations_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + + std::vector in_grid_desc_m_k_container_; + std::vector out_grid_desc_m_k_container_; + + std::vector p_in_container_; + std::vector p_out_container_; + + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float elapsed_time = 0.f; + const auto kernel = kernel_tensor_rearrange, + GridwiseTensorRearrangeKernel>; + + // Execute each set of independent filters + for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++) + { + const auto block_2_tile_map = + BlockToCTileMap_M00_N0_M01Adapt( + arg.out_grid_desc_m_k_container_[i]); + const index_t grid_size = + block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]) * arg.G_; + elapsed_time += launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.in_grid_desc_m_k_container_[i], + arg.p_in_container_[i], + arg.out_grid_desc_m_k_container_[i], + arg.p_out_container_[i], + arg.G_, + block_2_tile_map, + arg.compute_ptr_offset_of_batch_); + } + return elapsed_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const Argument& arg) + { + using namespace tensor_layout::convolution; + if constexpr(!(is_NSpatialGC || is_GNSpatialC)) + { + return false; + } + + const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1]; + const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1]; + const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1]; + const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1]; + bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_; + bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1; + + // check vector acces with c not packed + if(!is_c_packed && ScalarPerVector != 1) + return false; + // check vector access of filter window row (only C if C is not packed) + if(!is_w_packed && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of filter window row (X * C) + if(arg.X_ * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of pads (w_pad_left/w_pad_right * C) + if(w_pad_left * arg.C_ % ScalarPerVector != 0 || + w_pad_right * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with stride and pad + if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with dilation + if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + + bool valid = true; + for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++) + { + valid &= GridwiseTensorRearrangeKernel::CheckValidity( + arg.in_grid_desc_m_k_container_[i], arg.out_grid_desc_m_k_container_[i]); + } + return valid; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_in, // input image + void* p_out, // output image + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + return Argument{static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in, // input image + void* p_out, // output image + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) override + { + return std::make_unique(static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceColumnToImage" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << KPerBlock << ", " + << ScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..309a32bf2f73cb369a20b2df67908634011e1b44 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,847 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = as_grid_desc_ak0_m_ak1; + ignore = bs_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceContractionMultipleABD_Xdl_CShuffle + : public DeviceContractionMultipleABD +{ + using DeviceOp = DeviceContractionMultipleABD_Xdl_CShuffle; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ComputeDataType = EDataType; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< + AsDataType, + BsDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + static constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(const std::vector& a_ms_ks_lengths_, + const std::vector& a_ms_ks_strides_) + { + assert(a_ms_ks_lengths_.size() == NumDimM + NumDimK && + a_ms_ks_strides_.size() == NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_, Number{}); + const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + __host__ __device__ static auto + MakeAsGridDescriptor_M_K(const std::array, NumATensor>& as_ms_ks_lengths, + const std::array, NumATensor>& as_ms_ks_strides) + { + return generate_tuple( + [&](auto i) { + return MakeAGridDescriptor_M_K(as_ms_ks_lengths[i], as_ms_ks_strides[i]); + }, + Number{}); + } + + // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_ns_ks_lengths_, + const std::vector& b_ns_ks_strides_) + { + assert(b_ns_ks_lengths_.size() == NumDimN + NumDimK && + b_ns_ks_strides_.size() == NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_, Number{}); + const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + __host__ __device__ static auto + MakeBsGridDescriptor_N_K(const std::array, NumBTensor>& bs_ns_ks_lengths, + const std::array, NumBTensor>& bs_ns_ks_strides) + { + return generate_tuple( + [&](auto i) { + return MakeBGridDescriptor_N_K(bs_ns_ks_lengths[i], bs_ns_ks_strides[i]); + }, + Number{}); + } + + // assume E[M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_ms_ns_lengths_, + const std::vector& e_ms_ns_strides_) + { + assert(e_ms_ns_lengths_.size() == NumDimM + NumDimN && + e_ms_ns_strides_.size() == NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_, Number{}); + const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto + MakeDsGridDescriptor_M_N(const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); + }, + Number{}); + } + + // desc for problem definition + using AsGridDesc_M_K = remove_cvref_t; + using BsGridDesc_N_K = remove_cvref_t; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = remove_cvref_t; + + // desc for blockwise copy + using AsGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BsGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::array p_as_grid, + std::array p_bs_grid, + std::array p_ds_grid, + void* p_e_grid, + const std::array, NumATensor>& a_ms_ks_lengths, + const std::array, NumATensor>& a_ms_ks_strides, + const std::array, NumBTensor>& b_ns_ks_lengths, + const std::array, NumBTensor>& b_ns_ks_strides, + const std::array, NumDTensor>& d_ms_ns_lengths, + const std::array, NumDTensor>& d_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + as_grid_desc_m_k_{}, + bs_grid_desc_n_k_{}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{MakeEGridDescriptor_M_N(e_ms_ns_length, e_ms_ns_stride)}, + as_grid_desc_ak0_m_ak1_{}, + bs_grid_desc_bk0_n_bk1_{}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + // populate pointer, desc for As + static_for<0, NumATensor, 1>{}([&](auto i) { + // using ALayout = remove_cvref_t>; + using ADataType = remove_cvref_t>; + + // A pointer + p_as_grid_(i) = static_cast(p_as_grid[i]); + + // A desc + as_grid_desc_m_k_(i) = + MakeAGridDescriptor_M_K(a_ms_ks_lengths[i], a_ms_ks_strides[i]); + }); + + // populate pointer, desc for Bs + static_for<0, NumBTensor, 1>{}([&](auto i) { + // using BLayout = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + // B pointer + p_bs_grid_(i) = static_cast(p_bs_grid[i]); + + // B desc + bs_grid_desc_n_k_(i) = + MakeBGridDescriptor_N_K(b_ns_ks_lengths[i], b_ns_ks_strides[i]); + }); + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + // using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + MakeEGridDescriptor_M_N(d_ms_ns_lengths[i], d_ms_ns_strides[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_, + bs_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + as_grid_desc_ak0_m_ak1_ = + GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_); + + bs_grid_desc_bk0_n_bk1_ = + GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + + // for sanity check of vector memory access + for(index_t i = 0; i < NumATensor; ++i) + { + a_mz_stride_[i] = a_ms_ks_strides[i][NumDimM - 1]; + a_kz_stride_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1]; + } + + for(index_t i = 0; i < NumBTensor; ++i) + { + b_nz_stride_[i] = b_ns_ks_strides[i][NumDimN - 1]; + b_kz_stride_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1]; + } + + for(index_t i = 0; i < NumDTensor; ++i) + { + ds_nz_stride_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1]; + } + + e_nz_stride_ = e_ms_ns_stride[NumDimM + NumDimN - 1]; + } + + // pointers + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AsGridDesc_M_K as_grid_desc_m_k_; + BsGridDesc_N_K bs_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_; + BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Strides for the last M/N/K dimensions of A/B/Ds/E + // for sanity check of vector load/store + std::array a_mz_stride_; + std::array a_kz_stride_; + + std::array b_nz_stride_; + std::array b_kz_stride_; + + std::array ds_nz_stride_; + index_t e_nz_stride_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_abd_xdl_cshuffle< + GridwiseGemm, + typename GridwiseGemm::AsGridPointer, + typename GridwiseGemm::BsGridPointer, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AsGridDesc_AK0_M_AK1, + DeviceOp::BsGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_as_grid_, + arg.p_bs_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.as_grid_desc_ak0_m_ak1_, + arg.bs_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // check vector load/store + { + bool all_valid = true; + + static_for<0, NumATensor, 1>{}([&](auto i) { + // vector memory access of A: could be on M or AK1 dimension + if constexpr(ABlockTransferSrcVectorDim == 1) + { + if(!(arg.a_mz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) % + ABlockTransferSrcScalarPerVector == + 0)) + { + all_valid = false; + } + } + else + { + if(!(arg.a_kz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) % + ABlockTransferSrcScalarPerVector == + 0)) + { + all_valid = false; + } + } + }); + + // vector memory access of B: could be on N or BK1 dimension + static_for<0, NumBTensor, 1>{}([&](auto i) { + if constexpr(BBlockTransferSrcVectorDim == 1) + { + if(!(arg.b_nz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) % + BBlockTransferSrcScalarPerVector == + 0)) + { + all_valid = false; + } + } + else + { + if(!(arg.b_kz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) % + BBlockTransferSrcScalarPerVector == + 0)) + { + all_valid = false; + } + } + }); + + // check vector load of Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + if(!(arg.ds_nz_stride_[i] == 1 && + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + all_valid = false; + } + }); + + // vector memory access of E: always on NPerBlock dimension + if(!(arg.e_nz_stride_ == 1 && + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + all_valid = false; + } + + if(!all_valid) + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + const std::array, NumATensor>& a_ms_ks_lengths, + const std::array, NumATensor>& a_ms_ks_strides, + const std::array, NumBTensor>& b_ns_ks_lengths, + const std::array, NumBTensor>& b_ns_ks_strides, + const std::array, NumDTensor>& d_ms_ns_lengths, + const std::array, NumDTensor>& d_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + d_ms_ns_lengths, + d_ms_ns_strides, + e_ms_ns_length, + e_ms_ns_stride, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + const std::array, NumATensor>& as_ms_ks_lengths, + const std::array, NumATensor>& as_ms_ks_strides, + const std::array, NumBTensor>& bs_ns_ks_lengths, + const std::array, NumBTensor>& bs_ns_ks_strides, + const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + as_ms_ks_lengths, + as_ms_ks_strides, + bs_ns_ks_lengths, + bs_ns_ks_strides, + ds_ms_ns_lengths, + ds_ms_ns_strides, + e_ms_ns_length, + e_ms_ns_stride, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceContractionMultipleABD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 1eaffe705a74c9cdffc9a29f912f134c2bffbd97..5ae06836fab74e27e70ea5f04c2ea1f66f8540b8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -53,7 +53,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -310,9 +310,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, @@ -355,14 +359,18 @@ struct DeviceContractionMultipleD_Xdl_CShuffle LoopSched>; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; // block-to-e-tile map using Block2ETileMap = @@ -582,8 +590,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index b65afce8df5834d14f4ee9b3b19730e11db43bee..23440e24f6828c55b0160c3cb06ee667da9d3f60 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -532,11 +532,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ float ave_time = 0; const auto Run = [&](const auto& kernel) { - hipGetErrorString(hipMemset( + hipGetErrorString(hipMemsetAsync( arg.p_c_grid_, 0, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * - sizeof(CDataType))); + sizeof(CDataType), + stream_config.stream_id_)); ave_time = launch_and_time_kernel(stream_config, @@ -649,6 +650,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + // vector load A/B matrix from global memory if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index ea3020663a88c1ce565beab17313e72928f50461..e22c5a2aa514d4ace3b05074be06a539a7b54433 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -379,9 +379,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K AccDataType, CDataType, InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, @@ -428,20 +425,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) : p_a_grid_{p_out_grid}, p_b_grid_{p_wei_grid}, p_c_grid_{p_in_grid}, - M01_{M01}, - N01_{N01}, - a_element_op_{out_element_op}, - b_element_op_{wei_element_op}, - c_element_op_{in_element_op}, Conv_N_{N}, Conv_K_{K}, Conv_C_{C}, @@ -495,18 +482,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - - auto block_2_ctile_map = - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01); - - if(GridwiseGemm::CheckValidity( - descs[I0], descs[I1], descs[I2], block_2_ctile_map)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); - - block_2_ctile_map_container_.push_back(block_2_ctile_map); - } } } } @@ -517,14 +492,6 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector a_grid_desc_k0_m_k1_container_; std::vector b_grid_desc_k0_n_k1_container_; std::vector c_grid_desc_m_n_container_; - std::vector - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_; - std::vector block_2_ctile_map_container_; - index_t M01_; - index_t N01_; - OutElementwiseOperation a_element_op_; - WeiElementwiseOperation b_element_op_; - InElementwiseOperation c_element_op_; // for checking IsSupportedArgument() index_t Conv_N_; index_t Conv_K_; @@ -567,103 +534,68 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5) - << " ) " << std::endl; } #endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i], - arg.block_2_ctile_map_container_[i])) + arg.c_grid_desc_m_n_container_[i])) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); } - const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( - arg.c_grid_desc_m_n_container_[i]); + const auto [gdx, gdy, gdz] = + GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]); const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - OutElementwiseOperation, - WeiElementwiseOperation, - InElementwiseOperation, - remove_reference_t, - true>; - - ave_time += launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_container_[i]); + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); } else { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - OutElementwiseOperation, - WeiElementwiseOperation, - InElementwiseOperation, - remove_reference_t, - false>; - - ave_time += launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_container_[i]); + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); } } return ave_time; @@ -684,6 +616,11 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) { @@ -716,8 +653,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i], - arg.block_2_ctile_map_container_[i])) + arg.c_grid_desc_m_n_container_[i])) { return false; } @@ -742,10 +678,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) { return Argument{p_in_grid, p_wei_grid, @@ -759,12 +692,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op}; + input_right_pads}; } static auto MakeInvoker() { return Invoker{}; } @@ -783,9 +711,9 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) override + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation) override { return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), @@ -799,12 +727,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op); + input_right_pads); } std::unique_ptr MakeInvokerPointer() override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index e7e2bf33540fd8dd05ba1062d32241d39ff371cd..c9e8940edc4030be672d7c9def1225c9f650eee5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -810,6 +810,11 @@ struct static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index 6c4957b9b17f68c3a5d33470582d1d67907a8e85..28fceb428eff34628cc2f5c680bd9a45c0b856e6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -767,6 +767,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 027f1a1954da565721a1bf072f58e8feddcf29ca..ca291d3b11f5273caaf7900bf349eddceb108658 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -741,6 +741,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index 6278220c2f9dee047f632acef0e0d2f911156fa6..ef94120f4e89f199018fc62c5af481fc0aca4f83 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -329,9 +329,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K AccDataType, CDataType, InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, @@ -378,25 +375,13 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) : p_a_grid_{p_in_grid}, p_b_grid_{p_wei_grid}, p_c_grid_{p_out_grid}, a_grid_desc_k0_m_k1_{}, b_grid_desc_k0_n_k1_{}, c_grid_desc_m_n_{}, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - in_element_op_{in_element_op}, - wei_element_op_{wei_element_op}, - out_element_op_{out_element_op}, Conv_N_{N}, Conv_K_{K}, Conv_C_{C}, @@ -420,17 +405,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K a_grid_desc_k0_m_k1_ = descs[I0]; b_grid_desc_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - } } // private: @@ -440,14 +414,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - InElementwiseOperation in_element_op_; - WeiElementwiseOperation wei_element_op_; - OutElementwiseOperation out_element_op_; // for checking IsSupportedArgument() index_t Conv_N_; index_t Conv_K_; @@ -479,17 +445,14 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } #endif - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_)) { throw std::runtime_error( "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -498,22 +461,18 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation, - remove_reference_t, - true>; + const auto kernel = + kernel_gemm_xdlops_v2r3; ave_time = launch_and_time_kernel(stream_config, kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg.p_a_grid_, @@ -521,30 +480,22 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K arg.p_c_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.in_element_op_, - arg.wei_element_op_, - arg.out_element_op_, - arg.block_2_ctile_map_); + arg.c_grid_desc_m_n_); } else { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation, - remove_reference_t, - false>; + const auto kernel = + kernel_gemm_xdlops_v2r3; ave_time = launch_and_time_kernel(stream_config, kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg.p_a_grid_, @@ -552,11 +503,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K arg.p_c_grid_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.in_element_op_, - arg.wei_element_op_, - arg.out_element_op_, - arg.block_2_ctile_map_); + arg.c_grid_desc_m_n_); } return ave_time; @@ -577,6 +524,11 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) { @@ -616,10 +568,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K } // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_); + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); } bool IsSupportedArgument(const BaseArgument* p_arg) override @@ -639,10 +589,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) { return Argument{p_in_grid, p_wei_grid, @@ -656,12 +603,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op}; + input_right_pads}; } static auto MakeInvoker() { return Invoker{}; } @@ -680,9 +622,9 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) override + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation) override { return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), @@ -696,12 +638,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op); + input_right_pads); } std::unique_ptr MakeInvokerPointer() override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp index f69d8f18ae036b8e7cd060a80a0cca12bfc2a050..cd89f3232cb94e9ce9a68c06587758ac9574816f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef DEVICE_CONV3D_FWD_NAIVE_HPP #define DEVICE_CONV3D_FWD_NAIVE_HPP diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index d52879cd904f05d165557c08ae6d123fbfa18434..a8e586b20c38a40438636c5f9cafafb5c34409a9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef DEVICE_CONV3D_FWD_XDL_HPP #define DEVICE_CONV3D_FWD_XDL_HPP @@ -56,7 +56,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -524,6 +524,11 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index 2a2edc29bac75fa7c7e5db352fcaa934754cdae4..3178f73f4b530e67b399cd0ae71064a4b66fa837 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -1393,7 +1393,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl static bool IsSupportedArgument(const Argument& arg) { // check device - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")) + if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || + ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index 1fd4b76cec4cad5fd23b6a64166d210abdcb902a..ee3f0cea1b37f1de5ce3f53d6252e62024c8ddf8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -980,9 +980,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl AccDataType, CDataType, InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, @@ -1029,20 +1026,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) : p_a_grid_{p_out_grid}, p_b_grid_{p_wei_grid}, p_c_grid_{p_in_grid}, - M01_{M01}, - N01_{N01}, - a_element_op_{out_element_op}, - b_element_op_{wei_element_op}, - c_element_op_{in_element_op}, Conv_N_{N}, Conv_K_{K}, Conv_C_{C}, @@ -1092,17 +1079,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - - auto block_2_ctile_map = - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); - - if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], block_2_ctile_map)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); - - block_2_ctile_map_container_.push_back(block_2_ctile_map); - } } } template ::type = false> @@ -1150,18 +1126,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - - auto block_2_ctile_map = - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); - - if(GridwiseGemm::CheckValidity( - descs[I0], descs[I1], descs[I2], block_2_ctile_map)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); - - block_2_ctile_map_container_.push_back(block_2_ctile_map); - } } } } @@ -1218,19 +1182,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - - auto block_2_ctile_map = - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); - - if(GridwiseGemm::CheckValidity( - descs[I0], descs[I1], descs[I2], block_2_ctile_map)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2( - descs[I2])); - - block_2_ctile_map_container_.push_back(block_2_ctile_map); - } } } } @@ -1242,11 +1193,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl std::vector a_grid_desc_k0_m_k1_container_; std::vector b_grid_desc_k0_n_k1_container_; std::vector c_grid_desc_m_n_container_; - std::vector - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_; - std::vector block_2_ctile_map_container_; - index_t M01_; - index_t N01_; OutElementwiseOperation a_element_op_; WeiElementwiseOperation b_element_op_; InElementwiseOperation c_element_op_; @@ -1276,123 +1222,84 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl { #if DEBUG_LOG { - std::cout << "arg.a_grid_desc_k0_m_k1_container_{" + std::cout << "arg.a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" << std::endl; - std::cout << "arg.b_grid_desc_k0_n_k1_container_{" + std::cout << "arg.b_grid_desc_k0_n_k1{" << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" << std::endl; - std::cout << "arg.c_grid_desc_m_n_container_{ " + std::cout << "arg.c_grid_desc_m_n{" << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I6) - << ", " - << arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I7) - << " ) " << std::endl; } #endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i], - arg.block_2_ctile_map_container_[i])) + arg.c_grid_desc_m_n_container_[i])) { throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); } - const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( - arg.c_grid_desc_m_n_container_[i]); + const auto [gdx, gdy, gdz] = + GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]); const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - OutElementwiseOperation, - WeiElementwiseOperation, - InElementwiseOperation, - remove_reference_t, - true>; - - ave_time += launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_container_[i]); + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); } else { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - OutElementwiseOperation, - WeiElementwiseOperation, - InElementwiseOperation, - remove_reference_t, - false>; - - ave_time += launch_and_time_kernel( - stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_container_[i], - arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i], - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_container_[i]); + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); } } return ave_time; @@ -1413,6 +1320,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) { @@ -1446,8 +1358,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], - arg.c_grid_desc_m_n_container_[i], - arg.block_2_ctile_map_container_[i])) + arg.c_grid_desc_m_n_container_[i])) { return false; } @@ -1472,10 +1383,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + std::vector input_right_pads) { return Argument{p_in_grid, p_wei_grid, @@ -1489,12 +1397,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op}; + input_right_pads}; } static auto MakeInvoker() { return Invoker{}; } @@ -1513,9 +1416,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) override + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation) override { return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), @@ -1529,12 +1432,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl conv_filter_strides, conv_filter_dilations, input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op); + input_right_pads); } std::unique_ptr MakeInvokerPointer() override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp index 83ed6198bd3c0921988d0d867a87a617bc2c76a1..02ef29e32ddc2bbd985fb4f46c6eb3825c00c11a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,6 +13,7 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" namespace ck { namespace tensor_operation { @@ -171,10 +172,7 @@ struct DeviceElementwise2dImpl : public DeviceElementwise 0, ""); static_assert(NumDim_n > 0, ""); @@ -192,34 +190,10 @@ struct DeviceElementwise2dImpl : public DeviceElementwise(out_dev_buffers[I.value]); }, Number{}); - - in_grid_2d_desc_tuple_ = generate_tuple( - [&](auto I) { - return MakeDescriptor_MN(lengths, - inStridesArray[I.value], - gridSize_, - blockSize_, - num_threads_m_, - num_threads_n_); - }, - Number{}); - - out_grid_2d_desc_tuple_ = generate_tuple( - [&](auto I) { - return MakeDescriptor_MN(lengths, - outStridesArray[I.value], - gridSize_, - blockSize_, - num_threads_m_, - num_threads_n_); - }, - Number{}); } InDataTypePointerTuple in_dev_buffers_; OutDataTypePointerTuple out_dev_buffers_; - InGrid2dDescTuple in_grid_2d_desc_tuple_; - OutGrid2dDescTuple out_grid_2d_desc_tuple_; std::array lengths_; std::array, NumInput> inStridesArray_; @@ -227,15 +201,38 @@ struct DeviceElementwise2dImpl : public DeviceElementwise{}); + + auto out_grid_2d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_MN(arg.lengths_, + arg.outStridesArray_[I.value], + gridSize, + arg.blockSize_, + num_threads_m, + num_threads_n); + }, + Number{}); + const auto kernel = kernel_elementwise_2d(out_dev_buffers[I.value]); }, Number{}); - - in_grid_1d_desc_tuple_ = generate_tuple( - [&](auto I) { - return MakeDescriptor_M( - lengths, inStridesArray[I.value], gridSize_, blockSize_); - }, - Number{}); - - out_grid_1d_desc_tuple_ = generate_tuple( - [&](auto I) { - return MakeDescriptor_M( - lengths, outStridesArray[I.value], gridSize_, blockSize_); - }, - Number{}); } InDataTypePointerTuple in_dev_buffers_; OutDataTypePointerTuple out_dev_buffers_; - InGrid1dDescTuple in_grid_1d_desc_tuple_; - OutGrid1dDescTuple out_grid_1d_desc_tuple_; std::array lengths_; std::array, NumInput> inStridesArray_; @@ -187,13 +171,28 @@ struct DeviceElementwiseImpl ElementwiseOperation elementwise_op_; index_t blockSize_; - index_t gridSize_; }; struct Invoker : public BaseInvoker { float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + + auto in_grid_1d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_M( + arg.lengths_, arg.inStridesArray_[I.value], gridSize, arg.blockSize_); + }, + Number{}); + + auto out_grid_1d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_M( + arg.lengths_, arg.outStridesArray_[I.value], gridSize, arg.blockSize_); + }, + Number{}); + const auto kernel = kernel_elementwise_1d(); }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceElementwiseImpl<" ; + str << "NumDim_" << NumDim << ","; + str << "MPerThread_" << MPerThread << ","; + + str << "InScalarPerVector"; + static_for<0, InScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << InScalarPerVectorSeq::At(i).value; }); + str << ","; + str << "OutScalarPerVector"; + static_for<0, OutScalarPerVectorSeq::Size(), 1>{}([&](auto i) { str << "_" << OutScalarPerVectorSeq::At(i).value; }); + + str << ">"; + // clang-format on + + return str.str(); + } + }; // namespace device } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp index 1fa69288a4d725304181698af5b13f848b6fa5c7..c3416758d1277d6cc7d69e4aa4115b57af4c9cd1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp index a9a58c8ac425876e8b5be1bbe3a9f1ff0c08dcd0..5bfef60af285e4ada6bb5d98a2762ebe52e52fb2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -683,6 +683,11 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp deleted file mode 100644 index 9f9fe0f1c9cb425492d2e796da976f565b265f34..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp +++ /dev/null @@ -1,586 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_bias_e_permute(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatDsPointer p_ds_grid, - FloatE* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); -#else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_ds_grid; - ignore = p_e_grid; - ignore = a_element_op; - ignore = b_element_op; - ignore = cde_element_op; - ignore = a_grid_desc_ak0_m_ak1; - ignore = b_grid_desc_bk0_n_bk1; - ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = block_2_etile_map; -#endif -} - -} // namespace ck - -namespace ck { -namespace tensor_operation { -namespace device { - -// input : A[M, K], or A[K, N] -// input : B[K, N], or A[N, K] -// input : D0[M, N], D1[M, N], ... -// output : E[M, N] -// C = a_op(A) * b_op(B) -// E = cde_op(C, D0, D1, ...) -template -struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute -{ - using DeviceOp = DeviceGemmBiasEPermute_Xdl; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; - - static constexpr index_t NumDTensor = 1; - - static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) - { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideA)); - } - }(); - - return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); - } - - static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) - { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); - } - }(); - - return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); - } - - static auto MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1 d_e_grid_desc) - { - index_t M0 = d_e_grid_desc.M0_; - index_t M1 = d_e_grid_desc.M1_; - index_t M2 = d_e_grid_desc.M2_; - index_t N0 = d_e_grid_desc.N0_; - index_t N1 = d_e_grid_desc.N1_; - - index_t stride_M0 = d_e_grid_desc.stride_M0_; - index_t stride_M1 = d_e_grid_desc.stride_M1_; - index_t stride_M2 = d_e_grid_desc.stride_M2_; - index_t stride_N0 = d_e_grid_desc.stride_N0_; - index_t stride_N1 = d_e_grid_desc.stride_N1_; - - const auto e_grid_desc_mraw_nraw = [&]() { - const auto e_grid_desc_m0_m1_m2_n0_n1 = make_naive_tensor_descriptor( - make_tuple(M0, M1, M2, N0, N1), - make_tuple(stride_M0, stride_M1, stride_M2, stride_N0, stride_N1)); - - return transform_tensor_descriptor( - e_grid_desc_m0_m1_m2_n0_n1, - make_tuple(make_merge_transform(make_tuple(M0, M1, M2)), - make_merge_transform(make_tuple(N0, N1))), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - }(); - - return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); - } - - using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); - using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); - using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(DEGridDesc_M0_M1_M2_N0_N1{})); - - using DsGridDesc_M_N = Tuple; - - // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CShuffleDataType, - ck::Tuple, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, - NumGemmKPrefetchStage, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEBlockTransferScalarPerVector_NPerBlock, - LoopSched>; - - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - - using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; - - // Argument - struct Argument : public BaseArgument - { - Argument(const void* p_a_grid, - const void* p_b_grid, - const void* p_d_grid, - void* p_e_grid, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, - DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : p_a_grid_{static_cast(p_a_grid)}, - p_b_grid_{static_cast(p_b_grid)}, - p_ds_grid_{}, - p_e_grid_{static_cast(p_e_grid)}, - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, - b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, - ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_grid_desc)}, - a_grid_desc_ak0_m_ak1_{ - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, - b_grid_desc_bk0_n_bk1_{ - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - - if(MRaw != d_grid_desc.M0_ * d_grid_desc.M1_ * d_grid_desc.M2_) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - if(NRaw != d_grid_desc.N0_ * d_grid_desc.N1_) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - // populate pointer, desc for Ds - // D pointer - p_ds_grid_(I0) = static_cast(p_d_grid); - - // D desc - ds_grid_desc_m_n_(I0) = DeviceOp::MakeEGridDescriptor_M_N(d_grid_desc); - - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, - b_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_(I0) = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_[I0]); - } - } - - // private: - // pointers - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; - EDataType* p_e_grid_; - - // tensor descriptors for problem definiton - AGridDesc_M_K a_grid_desc_m_k_; - BGridDesc_N_K b_grid_desc_n_k_; - DsGridDesc_M_N ds_grid_desc_m_n_; - EGridDesc_M_N e_grid_desc_m_n_; - - // tensor descriptors for block/thread-wise copy - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_; - - // block-to-e-tile map - Block2ETileMap block_2_etile_map_; - - // element-wise op - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_)) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - const index_t grid_size = - arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); - - const auto K = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); - - auto launch_kernel = [&](auto has_main_k_block_loop) { - constexpr bool has_main_loop = has_main_k_block_loop.value; - - const auto kernel = kernel_gemm_bias_e_permute< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - typename GridwiseGemm::DsGridPointer, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::DefaultBlock2ETileMap, - has_main_loop>; - - return launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_ds_grid_, - arg.p_e_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_etile_map_); - }; - - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) - { - return launch_kernel(integral_constant{}); - } - else - { - return launch_kernel(integral_constant{}); - } - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static bool IsSupportedArgument(const Argument& arg) - { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) - { - return false; - } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, - arg.b_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(const void* p_a, - const void* p_b, - const void* p_d, - void* p_e, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, - DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{p_a, - p_b, - p_d, - p_e, - MRaw, - NRaw, - KRaw, - StrideA, - StrideB, - d_grid_desc, - e_grid_desc, - a_element_op, - b_element_op, - cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr - MakeArgumentPointer(const void* p_a, - const void* p_b, - const void* p_d, - void* p_e, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - DEGridDesc_M0_M1_M2_N0_N1 d_grid_desc, - DEGridDesc_M0_M1_M2_N0_N1 e_grid_desc, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) override - { - return std::make_unique(p_a, - p_b, - p_d, - p_e, - MRaw, - NRaw, - KRaw, - StrideA, - StrideB, - d_grid_desc, - e_grid_desc, - a_element_op, - b_element_op, - cde_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - // polymorphic - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGemmBiasEPermute_Xdl" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << KPerBlock << ", " - << AK1 << ", " - << BK1 << ", " - << K1 << ", " - << MPerXDL << ", " - << NPerXDL << ", " - << MXdlPerWave << ", " - << NXdlPerWave << ", " - << ABlockTransferSrcScalarPerVector << ", " - << ABlockTransferDstScalarPerVector_K1 << ", " - << BBlockTransferSrcScalarPerVector << ", " - << BBlockTransferDstScalarPerVector_K1 << ", " - << CShuffleMXdlPerWavePerShuffle << ", " - << CShuffleNXdlPerWavePerShuffle << ", " - << CBlockTransferScalarPerVector_NWaveNPerXdl - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index af1989fc4a2cc41783806bbe3c2f67718e45e524..514aa4452eb70b9cd548e6d697597f755081948f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -273,6 +273,9 @@ struct DeviceGemmDl : public DeviceGemm::value) { - return GridwiseGemm::CheckValidity( - arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + constexpr auto A_K_vec_length = + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I0) * + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I3); + if(arg.K_raw_ % A_K_vec_length != 0) + { + return false; + } + } + else + { + constexpr auto A_M_vec_lenght = + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I1) * + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I2); + if(arg.M_raw_ % A_M_vec_lenght != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + constexpr auto B_N_vec_lenght = + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I1) * + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I2); + if(arg.N_raw_ % B_N_vec_lenght != 0) + { + return false; + } } else { - return false; + constexpr auto B_K_vec_length = + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I0) * + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I3); + if(arg.K_raw_ % B_K_vec_length != 0) + { + return false; + } + } + + if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || + ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102") + { + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); } + return false; } // polymorphic @@ -570,7 +620,7 @@ struct DeviceGemmDl : public DeviceGemm + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmDpp : public DeviceGemm +{ + using GridwiseGemm = GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp< + BlockSize, + ADataType, + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + MPerBlock, + NPerBlock, + KPerBlock, + MPerDpp, + NPerDpp, + AK1, + BK1, + MDppPerWave, + NDppPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 1, 3, 5>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch, + PipelineVer>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_dpp has invalid setting"); + } + + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) + { + const auto kernel = kernel_gemm_dpp; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + else + { + const auto kernel = kernel_gemm_dpp; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1100" || + ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102") + { + return GridwiseGemm::CheckValidity(karg); + } + return false; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmDpp" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerDpp << ", " + << NPerDpp << ", " + << MDppPerWave << ", " + << MDppPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 + << ">" + << " NumPrefetch: " + << NumPrefetch << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..827a341a50357faa164eed6b26ee83fde7802c7e --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,766 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = as_grid_desc_ak0_m_ak1; + ignore = bs_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD +{ + using DeviceOp = DeviceGemmMultipleABD_Xdl_CShuffle; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + +#if 0 + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideAs) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideAs, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideAs)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideBs) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideBs)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideBs, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } +#endif + using ComputeDataType = EDataType; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< + AsDataType, + BsDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + // desc for problem definition + using AsGridDesc_M_K = + remove_cvref_t( + {}, {}, {}))>; + using BsGridDesc_N_K = + remove_cvref_t( + {}, {}, {}))>; + using DsGridDesc_M_N = + remove_cvref_t( + {}, {}, {}))>; + using EGridDesc_M_N = + decltype(GridwiseGemm::template MakeEGridDescriptor_M_N(1, 1, 1)); + + // desc for blockwise copy + using AsGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BsGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::array p_as_grid, + std::array p_bs_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + as_grid_desc_m_k_{}, + bs_grid_desc_n_k_{}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{GridwiseGemm::template MakeEGridDescriptor_M_N( + MRaw, NRaw, StrideE)}, + as_grid_desc_ak0_m_ak1_{}, + bs_grid_desc_bk0_n_bk1_{}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw} + { + // populate pointer, desc for As + static_for<0, NumATensor, 1>{}([&](auto i) { + using ALayout = remove_cvref_t>; + using ADataType = remove_cvref_t>; + + // A pointer + p_as_grid_(i) = static_cast(p_as_grid[i]); + + // A desc + as_grid_desc_m_k_(i) = + GridwiseGemm::template MakeAGridDescriptor_M_K( + MRaw, KRaw, StrideAs[i]); + }); + + // populate pointer, desc for Bs + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BLayout = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + // B pointer + p_bs_grid_(i) = static_cast(p_bs_grid[i]); + + // B desc + bs_grid_desc_n_k_(i) = + GridwiseGemm::template MakeBGridDescriptor_N_K( + KRaw, NRaw, StrideBs[i]); + }); + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + GridwiseGemm::template MakeEGridDescriptor_M_N( + MRaw, NRaw, StrideDs[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_, + bs_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + as_grid_desc_ak0_m_ak1_ = + GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_); + + bs_grid_desc_bk0_n_bk1_ = + GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + } + + void Print() const + { + // std::cout << "A[M, K]: " << as_grid_desc_m_k_ << std::endl; + // std::cout << "B[N, K]: " << bs_grid_desc_n_k_ << std::endl; + // static_for<0, NumDTensor, 1>{}( + //[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + // std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AsGridDesc_M_K as_grid_desc_m_k_; + BsGridDesc_N_K bs_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_; + BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_abd_xdl_cshuffle< + GridwiseGemm, + typename GridwiseGemm::AsGridPointer, + typename GridwiseGemm::BsGridPointer, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AsGridDesc_AK0_M_AK1, + DeviceOp::BsGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_as_grid_, + arg.p_bs_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.as_grid_desc_ak0_m_ak1_, + arg.bs_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // check vector load/store + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + bool all_valid = true; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ALayout = remove_cvref_t>; + // check vector load of A + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + all_valid = false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + all_valid = false; + } + } + else + { + all_valid = false; + } + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BLayout = remove_cvref_t>; + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + all_valid = false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + all_valid = false; + } + } + else + { + all_valid = false; + } + }); + + // check vector load of Ds + // only support RowMajor for now + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // check vector store of E + // only support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideAs, + StrideBs, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideAs, + StrideBs, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmMultipleABD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp index 63684693099458263cc34e893464ed42f7dd8295..ad51096db7f6d8aa1cd8b4a8c7590ff385009943 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -50,8 +50,9 @@ __global__ void const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx1030__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \ + defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); @@ -552,7 +553,10 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD( @@ -364,11 +364,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle using DsGridDesc_M_N = remove_cvref_t; // We have to separate mean var descriptor for gemm and layernorm bacause of different grid // layout(different padding) - using GemmMeanVarGridDesc_M_NBlock = decltype( - MakeMeanVarDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, 1)); + using GemmMeanVarGridDesc_M_NBlock = + decltype(MakeMeanVarDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, + 1)); - using GemmCountGridDesc_M_NBlock = decltype( - MakeCountDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, 1)); + using GemmCountGridDesc_M_NBlock = + decltype(MakeCountDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, + 1)); using LayernormMeanVarGridDesc_M_NBlock = decltype(MakeMeanVarDescriptor_M_N, @@ -807,7 +809,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle // workspace for welford intermediate mean workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64; - // workspace for welford intermediate mean + // workspace for welford intermediate variance workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64; // workspace for welford intermediate count @@ -855,8 +857,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 4c1c3ab7bf1d10d5f8f75f88c41064db1fb51481..916f29a904eba4172b13651249263d085b16f378 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -61,7 +61,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -337,10 +337,12 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle RThreadTransferDstScalarPerVector_MPerBlock, LoopSched>; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; @@ -555,8 +557,7 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp index b59bb2b3038e06ab698cc930aeacd0cd3a25c5cd..44b3518e2c950e15ec3d5cfd240c443029987622 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -273,7 +273,10 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of Ds + // only support RowMajor for now + bool all_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // check vector store of E + // only support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } return GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 2488101484a6d3674be42afe79b836d61058a58a..d98725cf9d805be928083479b8e78c7b5f8aa997 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -20,7 +20,8 @@ namespace ck { template (p_a_grid, @@ -143,7 +144,8 @@ template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeDataType = EDataType> struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; // block-to-e-tile map using Block2ETileMap = @@ -438,6 +446,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector store of C + // only support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp index a5051455bf18502afd366e04c6ca9cf4b6103826..b008a6409ea1cf6b4da069d095ae672106729178 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -75,132 +75,20 @@ struct DeviceGemmXdl : public DeviceGemm{}; - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) - { - const index_t K0 = K / K1; - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(M, PadM)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - else - { - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - } - - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) - { - const index_t K0 = K / K1; - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - return transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - else - { - return transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - } - - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - const auto c_grid_desc_m_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - }(); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - } - - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, CDataType, InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, + ALayout, + BLayout, + CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, + GemmSpec, MPerBlock, NPerBlock, K0PerBlock, @@ -232,173 +120,41 @@ struct DeviceGemmXdl : public DeviceGemm; - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, - c_grid_desc_m_n_{}, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op}, - kraw_{K} - { - a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); - b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); - c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - - if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, - b_grid_desc_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - index_t kraw_; - }; + using Argument = typename GridwiseGemm::Argument; // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceGemmXdl::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(stream_config.log_level_ > 0) { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + karg.Print(); } -#endif - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!GridwiseGemm::CheckValidity(karg)) { throw std::runtime_error( - "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K = - arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); float ave_time = 0; - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); } else { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); } return ave_time; @@ -418,7 +174,7 @@ struct DeviceGemmXdl : public DeviceGemm || is_same_v || is_same_v || is_same_v)) @@ -441,15 +198,12 @@ struct DeviceGemmXdl : public DeviceGemm(static_cast(p_a), static_cast(p_b), @@ -511,12 +252,7 @@ struct DeviceGemmXdl : public DeviceGemm + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeTypeA = CDataType, + typename ComputeTypeB = ComputeTypeA> struct DeviceGemm_Xdl_CShuffle : public DeviceGemm{}; static constexpr auto I2 = Number<2>{}; - static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) - { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideA)); - } - }(); - - const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; - const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; - - const auto MPad = M - MRaw; - const auto KPad = K - KRaw; - - if constexpr(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both M and K - assert(K % AK1 == 0); - - const auto AK0 = K / AK1; - - const auto a_grid_desc_m_k = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_right_pad_transform(MRaw, MPad), - make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad M, but not K - assert(KRaw % AK1 == 0); - - const auto AK0 = KRaw / AK1; - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_right_pad_transform(MRaw, MPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad K, but not M - assert(K % AK1 == 0); - - const auto AK0 = K / AK1; - - const auto a_grid_desc_m_k = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(MRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else - { - // not pad M or K - assert(KRaw % AK1 == 0); - - const auto AK0 = KRaw / AK1; - - const auto a_grid_desc_ak0_m_ak1 = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(MRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - } - - static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) - { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); - } - }(); - - const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; - const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; - - const auto NPad = N - NRaw; - const auto KPad = K - KRaw; - - if constexpr(GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both N and K - assert(K % BK1 == 0); - - const auto BK0 = K / BK1; - - const auto b_grid_desc_n_k = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_right_pad_transform(NRaw, NPad), - make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad N, but not K - assert(KRaw % BK1 == 0); - - const auto BK0 = KRaw / BK1; - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad K, but not N - assert(K % BK1 == 0); - - const auto BK0 = K / BK1; - - const auto b_grid_desc_n_k = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(NRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else - { - // not pad N or K - assert(KRaw % BK1 == 0); - - const auto BK0 = KRaw / BK1; - - const auto b_grid_desc_bk0_n_bk1 = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(NRaw)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - } - - static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) - { - const auto c_grid_desc_mraw_nraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), - make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), - make_tuple(I1, StrideC)); - } - }(); - - const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; - const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; - - const auto MPad = M - MRaw; - const auto NPad = N - NRaw; - - if constexpr(GemmSpec == GemmSpecialization::MNPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad M and N - return transform_tensor_descriptor(c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(MRaw, MPad), - make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad M, but not N - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad N, but not M - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - // not pad M or N - return c_grid_desc_mraw_nraw; - } - } - - using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); - using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< - ADataType, // TODO: distinguish A/B datatype + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, + GemmSpec, InMemoryDataOperationEnum::Set, - AGridDesc_AK0_M_AK1, - BGridDesc_BK0_N_BK1, - CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, @@ -395,164 +131,47 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm; + PipelineVer, + ComputeTypeA, + ComputeTypeB>; - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, - b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, - c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op}, - kraw_{KRaw} - { - if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, - b_grid_desc_bk0_n_bk1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - index_t kraw_; - }; + using Argument = typename GridwiseGemm::Argument; // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceOp::Argument; - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { -#if DEBUG_LOG + if(stream_config.log_level_ > 0) { - std::cout << "arg.a_grid_desc_ak0_m_ak1_{" - << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " - << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " - << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_bk0_n_bk1_{" - << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " - << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " - << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + arg.Print(); } -#endif - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!GridwiseGemm::CheckValidity(arg)) { throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); - const auto K = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1; float ave_time = 0; if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_xdl_cshuffle_v1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::DefaultBlock2CTileMap, - true>; - - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdl_cshuffle_v1; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); } else { - const auto kernel = kernel_gemm_xdl_cshuffle_v1< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, - typename GridwiseGemm::DefaultBlock2CTileMap, - false>; - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); + const auto kernel = kernel_gemm_xdl_cshuffle_v1; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); } return ave_time; @@ -574,25 +193,20 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm MakeArgumentPointer(const void* p_a, const void* p_b, void* p_c, - index_t MRaw, - index_t NRaw, - index_t KRaw, + index_t M, + index_t N, + index_t K, index_t StrideA, index_t StrideB, index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) override + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override { return std::make_unique(static_cast(p_a), static_cast(p_b), static_cast(p_c), - MRaw, - NRaw, - KRaw, + M, + N, + K, StrideA, StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op); + StrideC); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp index 8ee138f8277295cc2d87c17b9a087c54eed4bc64..a8cd39080cabc3fac597ae481a8d72d668652202 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -648,8 +648,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp index 36b01f677f080e4f2f515738af6171d6ed4dbbcd..d54ddf433e49f96c7617de50bb016d427974f4f0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -416,6 +416,11 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm + index_t CBlockTransferScalarPerVector_NWaveNPerXDL, + typename ComputeType = CDataType, + PipelineVersion PipelineVer = PipelineVersion::v1> + struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK + CElementwiseOperation, + ComputeType> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto K1Number = Number{}; - - static auto - MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad) - { - assert(KPad % (K1 * KBatch) == 0); - - const index_t K0 = KPad / (K1 * KBatch); - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - const auto a_grid_desc_m_kpad = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - return transform_tensor_descriptor( - a_grid_desc_m_kpad, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), - make_right_pad_transform(M, PadM)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - } - else - { - return transform_tensor_descriptor( - a_grid_desc_m_kpad, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - } - } - - static auto - MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad) - { - assert(KPad % (K1 * KBatch) == 0); - - const index_t K0 = KPad / (K1 * KBatch); - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - const auto b_grid_desc_kpad_n = transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - return transform_tensor_descriptor( - b_grid_desc_kpad_n, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), - make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - } - else - { - return transform_tensor_descriptor( - b_grid_desc_kpad_n, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - } - } - - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - const auto c_grid_desc_m_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - }(); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { + // TODO: should be exposed as Tparams. + static constexpr index_t NumGemmKPrefetchStage = 1; + static constexpr LoopScheduler LoopSched = make_default_loop_scheduler(); - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - } - - static auto GetKPad(index_t K, index_t KBatch) - { - const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; - const index_t KPad = KBatch * K0 * K1; - return KPad; - } - - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - - // GridwiseGemm using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, - ADataType, // TODO: distinguish A/B datatype + ADataType, + BDataType, AccDataType, CDataType, - InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, + ALayout, + BLayout, + CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, + GemmSpec, + NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, @@ -251,238 +122,133 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer, + ComputeType>; - // GridwiseGemm - using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< - BlockSize, - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CDataType, - InMemoryDataOperationEnum::AtomicAdd, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - MPerBlock, - NPerBlock, - K0PerBlock, - MPerXDL, - NPerXDL, - K1, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - ABlockLdsAddExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, // BThreadTransferSrcResetCoordinateAfterRun, - BBlockLdsAddExtraN, - CShuffleMRepeatPerShuffle, - CShuffleNRepeatPerShuffle, - CBlockTransferScalarPerVector_NWaveNPerXDL, - CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; - - using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); - - using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor; - - // Argument - struct Argument : public BaseArgument + struct Argument : public GridwiseGemm::Argument { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t k_batch) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - a_grid_desc_kbatch_k0_m_k1_{}, - b_grid_desc_kbatch_k0_n_k1_{}, - c_grid_desc_m_n_{}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op}, - k_batch_{k_batch} + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + MPadded_, + NPadded_, + KPadded_, + K0_, + k_batch_), + a_element_op(a_element_op_), + b_element_op(b_element_op_), + c_element_op(c_element_op_) { - int KPad = DeviceGemmXdlSplitKCShuffle::GetKPad(K, k_batch_); - - a_grid_desc_kbatch_k0_m_k1_ = - DeviceGemmXdlSplitKCShuffle::MakeAGridDescriptor_KBatch_K0_M_K1( - M, K, StrideA, k_batch_, KPad); - b_grid_desc_kbatch_k0_n_k1_ = - DeviceGemmXdlSplitKCShuffle::MakeBGridDescriptor_KBatch_K0_N_K1( - K, N, StrideB, k_batch_, KPad); - c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffle::MakeCGridDescriptor_M_N(M, N, StrideC); - - block_2_ctile_map_ = - GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); - - if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, - b_grid_desc_kbatch_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); - } } - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; - Block2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - index_t k_batch_; + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CElementwiseOperation c_element_op; }; + using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceGemmXdlSplitKCShuffle::Argument; - void Print(const Argument& arg) - { - std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - } + void Print(const Argument& karg) { karg.Print(); } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { - Print(arg); + Print(karg); } - const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + const auto kbatch = karg.k_batch; - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!GridwiseGemm::CheckValidity(karg)) { throw std::runtime_error( - "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting"); + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid " + "setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const auto b2c_map = DefaultBlock2CTileMap{}; + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch); + const auto K0 = karg.K0; const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); float ave_time = 0; const auto Run = [&](const auto& kernel) { - hipGetErrorString(hipMemset( - arg.p_c_grid_, - 0, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * - sizeof(CDataType))); + if(kbatch > 1) + hipGetErrorString(hipMemsetAsync(karg.p_c_grid, + 0, + karg.M * karg.N * sizeof(CDataType), + stream_config.stream_id_)); ave_time = launch_and_time_kernel(stream_config, kernel, - dim3(grid_size), + dim3(gdx, gdy, gdz), dim3(BlockSize), 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + static_cast(karg), + b2c_map, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); }; if(has_main_k0_block_loop) { if(kbatch == 1) { - const auto kernel = kernel_gemm_xdlops_v2r4r2< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; Run(kernel); } else { - const auto kernel = kernel_gemm_xdlops_v2r4r2< - GridwiseGemmAtomicAdd, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; Run(kernel); } @@ -491,37 +257,27 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; Run(kernel); } else { - const auto kernel = kernel_gemm_xdlops_v2r4r2< - GridwiseGemmAtomicAdd, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; Run(kernel); } @@ -544,12 +300,14 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK"; - // clang-format on - - return str.str(); - } + std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8de42ba9efcdc018d0d945ace9aa3e2da9a41cbf --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp @@ -0,0 +1,360 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdlStreamK : public DeviceGemmStreamK +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< + BlockSize, + BlockToCTileMap_GemmStreamK, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + void Print(const Argument& karg) { karg.Print(); } + + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(karg); + } + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid " + "setting"); + } + + dim3 grid_dims = karg.block_mapping.get_grid_dims(); + + float ave_time = 0; + + const auto kernel = kernel_gemm_xdlops_streamk; + + // TODO: remove clear buffer for streamk kernels + if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Atomic) + { + hipGetErrorString(hipMemsetAsync(karg.p_c_grid, + 0, + karg.M * karg.N * sizeof(CDataType), + stream_config.stream_id_)); + ave_time = launch_and_time_kernel(stream_config, + kernel, + grid_dims, + dim3(BlockSize), + 0, + karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + karg.p_workspace_, + karg.M, + karg.N, + karg.K, + karg.StrideA, + karg.StrideB, + karg.StrideC, + karg.block_mapping); + } + else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + char* workspace_semaphore = reinterpret_cast(karg.p_workspace_) + + karg.block_mapping.get_workspace_size_for_acc( + sizeof(typename GridwiseGemm::FloatAcc)); + auto preprocess = [&]() { + hipGetErrorString( + hipMemsetAsync(workspace_semaphore, + 0, + karg.block_mapping.get_workspace_size_for_semaphore(), + stream_config.stream_id_)); + }; + + ave_time = launch_and_time_kernel_with_preprocess(stream_config, + preprocess, + kernel, + grid_dims, + dim3(BlockSize), + 0, + karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + karg.p_workspace_, + karg.M, + karg.N, + karg.K, + karg.StrideA, + karg.StrideB, + karg.StrideC, + karg.block_mapping); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* p_arg = dynamic_cast(pArg); + if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + return p_arg->block_mapping.get_workspace_size(sizeof(typename GridwiseGemm::FloatAcc)); + } + else + { + return 0; + } + } + + void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + } + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942")) + { + return false; + } + return GridwiseGemm::CheckValidity(karg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + uint32_t NumSKBlocks = 0xffffffff) + { + const auto kernel = kernel_gemm_xdlops_streamk; + int occupancy, num_cu; + hipError_t rtn; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte()); + hip_check_error(rtn); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + rtn = hipGetDevice(&dev); + hip_check_error(rtn); + rtn = hipGetDeviceProperties(&dev_prop, dev); + hip_check_error(rtn); + num_cu = dev_prop.multiProcessorCount; + + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + static_cast(num_cu), + static_cast(occupancy), + NumSKBlocks}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + index_t NumSKBlocks = 0) override + { + const auto kernel = kernel_gemm_xdlops_streamk; + int occupancy, num_cu; + hipError_t rtn; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte()); + hip_check_error(rtn); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + rtn = hipGetDevice(&dev); + hip_check_error(rtn); + rtn = hipGetDeviceProperties(&dev_prop, dev); + hip_check_error(rtn); + num_cu = dev_prop.multiProcessorCount; + + return std::make_unique(reinterpret_cast(p_a), + reinterpret_cast(p_b), + reinterpret_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + static_cast(num_cu), + static_cast(occupancy), + static_cast(NumSKBlocks)); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp similarity index 97% rename from include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp index af38f142549c6475a7fa983585f0a845286aa16e..1b34e2dba284c6996ccc85b86f424994a1591423 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -48,7 +48,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -248,10 +248,12 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; @@ -417,8 +419,7 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, @@ -400,14 +404,18 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle LoopSched>; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; struct GroupedContractionBlock2ETileMap { @@ -652,11 +660,12 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle } } - hipGetErrorString(hipMemcpy(arg.p_workspace_, - arg.contraction_multi_d_kernel_args_.data(), - arg.contraction_multi_d_kernel_args_.size() * - sizeof(ContractionMultiDKernelArg), - hipMemcpyHostToDevice)); + hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_, + arg.contraction_multi_d_kernel_args_.data(), + arg.contraction_multi_d_kernel_args_.size() * + sizeof(ContractionMultiDKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); float ave_time = 0; @@ -704,8 +713,7 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e2f9f24bf50080c10b3d440e9b87f85bdcac34dc --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,879 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Conv backward data multiple D: +// input : output image A: [G, N, K, Ho, Wo] +// input : weight B: [G, K, C, Y, X], +// input : D0, D1, ... : [G, N, K, Ho, Wo] +// output : input image E: [G, N, C, Hi, Wi] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + : public DeviceGroupedConvBwdDataMultipleD +{ + // TODO: Extend support for more spatial dimensions. + static_assert(NDimSpatial == 2 || NDimSpatial == 3, + "wrong! only implemented for 2D and 3D now"); + + using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + // TODO: Add support for different A and B data types. + using ABDataType = ADataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr index_t KPerBlock = K0PerBlock * K1; + + static constexpr auto transform_conv_to_gemm = + TransformConvBwdDataToGemm_v1{}; + + static auto GetDummyABDsEGridDescriptor() + { + const std::array dummy_tensor_lengths = {1}; + const std::array dummy_tensor_strides = {1}; + const std::array dummy_spatial_lengths = {1}; + + const auto a_grid_desc_ak0_m_ak1 = + transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + const auto b_grid_desc_bk0_n_bk1 = + transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + const auto ds_grid_desc_m_n = generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return transform_conv_to_gemm.template MakeCDescriptor_M_N( + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + }, + Number{}); + + const auto e_grid_desc_m_n = + transform_conv_to_gemm.template MakeCDescriptor_M_N(dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_tensor_lengths, + dummy_tensor_strides, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths, + dummy_spatial_lengths); + + return make_tuple( + a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + + // desc + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor()); + + using AGridDesc_AK0_M_AK1 = remove_cvref_t>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t>; + using DsGridDesc_M_N = remove_cvref_t>; + using EGridDesc_M_N = remove_cvref_t>; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + // InMemory Data Descriptor + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + DsGridDesc_M_N, + EGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + K0PerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{})); + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, + const std::array& a_g_n_k_wos_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, + const std::array& e_g_n_c_wis_lengths, + const std::array& e_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_k_wos_lengths[0]}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, + a_g_n_k_wos_strides_{a_g_n_k_wos_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths}, + ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides}, + e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, + e_g_n_c_wis_strides_{e_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // populate Ds pointer + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds[i]); + }); + + // A/B/Ds/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0]; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; + }); + + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + // problem definition + const index_t Z = b_g_k_c_xs_lengths[ZIdx]; + const index_t Y = b_g_k_c_xs_lengths[YIdx]; + const index_t X = b_g_k_c_xs_lengths[XIdx]; + + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; + + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice * ZDotSlice <= 0) + { + continue; + } + + std::array tildes; + if constexpr(NDimSpatial == 2) + { + tildes = {i_ytilde, i_xtilde}; + } + else if constexpr(NDimSpatial == 3) + { + tildes = {i_ztilde, i_ytilde, i_xtilde}; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } + + const auto a_grid_desc_ak0_m_ak1 = + transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + const auto b_grid_desc_bk0_n_bk1 = + transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + DsGridDesc_M_N ds_grid_desc_m_n; + + // populate Ds desc + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(i) = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths[i], + ds_g_n_c_wis_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + }); + + const auto e_grid_desc_m_n = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + // for check validity + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); + + // desc for blockwise copy + a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1); + b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1); + + // block-to-e-tile-map + auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap( + e_grid_desc_m_n, 1 /* M01 */, 1 /* N01 */); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n)); + e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n)); + } + } + } + } + + void Print() const + { + for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + std::cout << "a_grid_desc_ak0_m_ak1_container_" + << a_grid_desc_ak0_m_ak1_container_[i] << std::endl; + + std::cout << "b_grid_desc_bk0_n_bk1_container_" + << b_grid_desc_bk0_n_bk1_container_[i] << std::endl; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i][j] + << std::endl; + }); + + std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i] + << std::endl; + } + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptor for problem definition + index_t num_group_; + std::vector ds_grid_desc_m_n_container_; + std::vector e_grid_desc_m_n_container_; + + // tensor descriptor for block-wise copy + std::vector a_grid_desc_ak0_m_ak1_container_; + std::vector b_grid_desc_bk0_n_bk1_container_; + std::vector + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_; + std::vector + e_grid_desc_mblock_mperblock_nblock_nperblock_container_; + + // block-to-e-tile map + std::vector block_2_ctile_map_container_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOp a_element_op_; + BElementwiseOp b_element_op_; + CDEElementwiseOp cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_k_wos_lengths_; + std::array a_g_n_k_wos_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_c_wis_lengths_; + std::array, NumDTensor> ds_g_n_c_wis_strides_; + std::array e_g_n_c_wis_lengths_; + std::array e_g_n_c_wis_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float ave_time = 0; + + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( + arg.e_grid_desc_m_n_container_[i]) * + arg.num_group_; + + const auto GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * + arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_k_wos_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_container_[i], + arg.b_grid_desc_bk0_n_bk1_container_[i], + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.block_2_ctile_map_container_[i], + arg.compute_ptr_offset_of_batch_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK)) + { + ave_time += launch_kernel(integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + // check device + if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102") + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + + // Specialization + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's a 1x1 convolution with stride=1 and no padding + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load for A matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector load for B matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v) + { + if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector store for Ds + bool ds_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + // vector load D matrix from global memory + if(!(ConvC % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + ds_valid = false; + } + } + else + { + ds_valid = false; + } + }); + + if(!ds_valid) + { + return false; + } + + // vector store for E + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + // vector store C matrix into global memory + if(!(ConvC % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_container_[i], + arg.b_grid_desc_bk0_n_bk1_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + return false; + } + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", " + << K1 << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 92c20a308f9bc3b8c5f98581ce3e627d30420ebb..88b859efbc7a002cc7128534182c707d48a20d0b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" @@ -24,51 +25,6 @@ namespace device { namespace { -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - /* * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. * @@ -131,7 +87,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -242,7 +198,9 @@ template + LoopScheduler LoopSched = make_default_loop_scheduler(), + typename AComputeType = ADataType, + typename BComputeType = AComputeType> struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 : public DeviceGroupedConvBwdDataMultipleD + CDEElementwiseOp, + AComputeType, + BComputeType> { - // FIXME - static_assert(NDimSpatial == 2, "wrong! only implemented for 2D now"); + // TODO: Extend support for more spatial dimensions. + static_assert(NDimSpatial == 2 || NDimSpatial == 3, + "wrong! only implemented for 2D and 3D now"); using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; static constexpr index_t NumDTensor = DsDataType::Size(); - // TODO make A/B datatype different + // TODO: Add support for different A and B data types. using ABDataType = ADataType; static constexpr auto I0 = Number<0>{}; @@ -279,6 +240,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 BK1, MPerBlock, NPerBlock, + KPerBlock, DoPadGemmM, DoPadGemmN>{}; @@ -354,7 +316,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ABDataType, // TODO: distinguish A/B datatype + ABDataType, + ABDataType, + AComputeType, AccDataType, CShuffleDataType, DsDataType, @@ -394,7 +358,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, - LoopSched>; + LoopSched, + PipelineVersion::v1, + BComputeType>; template static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) @@ -421,10 +387,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{})); - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype( - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{})); + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{})); // block-to-e-tile map using Block2ETileMap = @@ -459,7 +427,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_k_wos_lengths[0]}, - num_gemm_{}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, cde_element_op_{cde_element_op}, @@ -492,133 +459,172 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; }); + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = NDimSpatial == 2 ? Number{} + : Number{}; + // problem definition - const index_t Y = b_g_k_c_xs_lengths[3]; - const index_t X = b_g_k_c_xs_lengths[4]; + const index_t Z = b_g_k_c_xs_lengths[ZIdx]; + const index_t Y = b_g_k_c_xs_lengths[YIdx]; + const index_t X = b_g_k_c_xs_lengths[XIdx]; - const index_t ConvStrideH = conv_filter_strides_[0]; - const index_t ConvStrideW = conv_filter_strides_[1]; + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; - const index_t ConvDilationH = conv_filter_dilations_[0]; - const index_t ConvDilationW = conv_filter_dilations_[1]; + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1; const auto YTilde = ConvStrideH / GcdStrideDilationH; const auto XTilde = ConvStrideW / GcdStrideDilationW; - // number of GEMM - num_gemm_ = YTilde * XTilde; - - for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) { - for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) - { - // check slice is valid - const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); - const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); - if(YDotSlice * XDotSlice <= 0) + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) { - continue; - } - - const auto a_grid_desc_ak0_m_ak1 = - transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( - a_g_n_k_wos_lengths, - a_g_n_k_wos_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_c_wis_lengths, - e_g_n_c_wis_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - {i_ytilde, i_xtilde}); - - const auto b_grid_desc_bk0_n_bk1 = - transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( - a_g_n_k_wos_lengths, - a_g_n_k_wos_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_c_wis_lengths, - e_g_n_c_wis_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - {i_ytilde, i_xtilde}); - - DsGridDesc_M_N ds_grid_desc_m_n; - - // populate Ds desc - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DLayout = remove_cvref_t>; - - ds_grid_desc_m_n(i) = - transform_conv_to_gemm.template MakeCDescriptor_M_N( + // check slice is valid + const auto ZDotSlice = + NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice * ZDotSlice <= 0) + { + continue; + } + + std::array tildes; + if constexpr(NDimSpatial == 2) + { + tildes = {i_ytilde, i_xtilde}; + } + else if constexpr(NDimSpatial == 3) + { + tildes = {i_ztilde, i_ytilde, i_xtilde}; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } + + const auto a_grid_desc_ak0_m_ak1 = + transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( a_g_n_k_wos_lengths, a_g_n_k_wos_strides, b_g_k_c_xs_lengths, b_g_k_c_xs_strides, - ds_g_n_c_wis_lengths[i], - ds_g_n_c_wis_strides[i], + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, input_right_pads, - {i_ytilde, i_xtilde}); - }); - - const auto e_grid_desc_m_n = - transform_conv_to_gemm.template MakeCDescriptor_M_N( - a_g_n_k_wos_lengths, - a_g_n_k_wos_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_c_wis_lengths, - e_g_n_c_wis_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - {i_ytilde, i_xtilde}); - - // desc for problem definition - const auto a_grid_desc_m_k = transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1); - const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); - - a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); - b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); - ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); - e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); - - // desc for blockwise copy - a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1); - b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1); - - // block-to-e-tile-map - auto block_2_etile_map = - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); - - block_2_etile_map_container_.push_back(block_2_etile_map); - - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, - b_grid_desc_n_k, - ds_grid_desc_m_n, - e_grid_desc_m_n, - block_2_etile_map)) - { - ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n)); + tildes); - e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n)); + const auto b_grid_desc_bk0_n_bk1 = + transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + DsGridDesc_M_N ds_grid_desc_m_n; + + // populate Ds desc + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(i) = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths[i], + ds_g_n_c_wis_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + }); + + const auto e_grid_desc_m_n = + transform_conv_to_gemm.template MakeCDescriptor_M_N( + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes); + + // desc for problem definition + const auto a_grid_desc_m_k = + transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1); + const auto b_grid_desc_n_k = + transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); + + a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); + b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); + + // desc for blockwise copy + a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1); + b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1); + + // block-to-e-tile-map + auto block_2_etile_map = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + + block_2_etile_map_container_.push_back(block_2_etile_map); + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n)); + + e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n)); + } } } } @@ -626,7 +632,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 void Print() const { - for(index_t i = 0; i < num_gemm_; i++) + for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++) { std::cout << "a_grid_desc_ak0_m_ak1_container_" << a_grid_desc_ak0_m_ak1_container_[i] << std::endl; @@ -654,7 +660,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // tensor descriptor for problem definition index_t num_group_; - index_t num_gemm_; std::vector a_grid_desc_m_k_container_; std::vector b_grid_desc_n_k_container_; std::vector ds_grid_desc_m_n_container_; @@ -708,7 +713,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 float ave_time = 0; - for(index_t i = 0; i < arg.num_gemm_; i++) + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], arg.b_grid_desc_n_k_container_[i], @@ -788,6 +793,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; @@ -807,7 +817,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // vector load for A matrix from global memory to LDS - if constexpr(is_same_v) + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) { if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) { @@ -820,7 +833,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // vector load for B matrix from global memory to LDS - if constexpr(is_same_v) + if constexpr(is_same_v || + is_same_v) { if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) { @@ -839,7 +853,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 using DLayout = remove_cvref_t>; if constexpr(is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || is_same_v || is_same_v) @@ -862,7 +878,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } // vector store for E - if constexpr(is_same_v) + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) { // vector store C matrix into global memory if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp similarity index 55% rename from include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 9529744f31daffb294a27c5a96f6708c22fc3a8a..e2c9dca9047284e39a32803fb7ed3006240b59af 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -12,8 +12,10 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -21,32 +23,6 @@ namespace ck { namespace tensor_operation { namespace device { -namespace { - -struct ComputePtrOffsetOfStridedBatch -{ - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - index_t BatchStrideC_; -}; - -} // namespace - template {}, integral_constant{}); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = batch_count; + ignore = a_grid_desc_kbatch_k0_m0_m1_k1; + ignore = b_grid_desc_kbatch_k0_n0_n1_k1; + ignore = c_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = block_2_ctile_map; + ignore = compute_ptr_offset_of_batch; +#endif } template -struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl - : public DeviceGroupedConvBwdWeight< - NDimSpatial, - ck::tuple_element_t>, - ck::tuple_element_t>, - ck::tuple_element_t>, - InDataType, - WeiDataType, - OutDataType, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> +struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight { - using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl; + // 1d + static constexpr bool is_NWGK_GKXC_NWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNWK_GKXC_GNWC = + is_same_v && + is_same_v && + is_same_v; + // 2d + static constexpr bool is_NHWGK_GKYXC_NHWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNHWK_GKYXC_GNHWC = + is_same_v && + is_same_v && + is_same_v; + // 3d + static constexpr bool is_NDHWGK_GKZYXC_NDHWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNDHWK_GKZYXC_GNDHWC = + is_same_v && + is_same_v && + is_same_v; + + using DeviceOp = DeviceGroupedConvBwdWeight_Dl; using ADataType = OutDataType; using BDataType = InDataType; @@ -176,6 +186,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl static constexpr auto I4 = Number<4>{}; static constexpr auto I5 = Number<5>{}; + static constexpr auto spatial_offset = I3; + static constexpr auto K1Number = Number{}; static constexpr auto GemmK1Number = K1Number; @@ -195,104 +207,116 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - ck::index_t batch_k) + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) { using namespace ck; - const index_t Wi = input_spatial_lengths[0]; - const index_t Wo = output_spatial_lengths[0]; - const index_t X = filter_spatial_lengths[0]; - const index_t InLeftPadW = input_left_pads[0]; - const index_t InRightPadW = input_right_pads[0]; - const index_t ConvStrideW = conv_filter_strides[0]; - const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset]; + const index_t InLeftPadW = input_left_pads[I0]; + const index_t InRightPadW = input_right_pads[I0]; + const index_t ConvStrideW = conv_filter_strides[I0]; + const index_t ConvDilationW = conv_filter_dilations[I0]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset]; const index_t GemmKTotal = N * Wo; - const index_t GemmM = K; - const index_t GemmN = C * X; - const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride)); - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C)); + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wi, C), make_tuple(InWStride, InCStride)); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weights tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } else { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); - const auto in_n_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride)); // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -321,143 +345,156 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } } // function end template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - ck::index_t batch_k) + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) { using namespace ck; - const index_t Hi = input_spatial_lengths[0]; - const index_t Wi = input_spatial_lengths[1]; - - const index_t Ho = output_spatial_lengths[0]; - const index_t Wo = output_spatial_lengths[1]; - - const index_t Y = filter_spatial_lengths[0]; - const index_t X = filter_spatial_lengths[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Hi = a_g_n_c_wis_lengths[spatial_offset]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I1]; + const index_t Ho = e_g_n_k_wos_lengths[spatial_offset]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I1]; + const index_t Y = b_g_k_c_xs_lengths[spatial_offset]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset + I1]; + + const index_t InLeftPadH = input_left_pads[I0]; + const index_t InLeftPadW = input_left_pads[I1]; + const index_t InRightPadH = input_right_pads[I0]; + const index_t InRightPadW = input_right_pads[I1]; + const index_t ConvStrideH = conv_filter_strides[I0]; + const index_t ConvStrideW = conv_filter_strides[I1]; + const index_t ConvDilationH = conv_filter_dilations[I0]; + const index_t ConvDilationW = conv_filter_dilations[I1]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InHStride = a_g_n_c_wis_strides[spatial_offset]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I1]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1]; const index_t GemmKTotal = N * Ho * Wo; - const index_t GemmM = K; - const index_t GemmN = C * X * Y; - const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C)); + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride)); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } else { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - const auto in_n_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(InNStride, InHStride, InWStride, InCStride)); // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -488,151 +525,166 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } } // function end template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - ck::index_t batch_k) + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) { using namespace ck; - const index_t Di = input_spatial_lengths[0]; - const index_t Hi = input_spatial_lengths[1]; - const index_t Wi = input_spatial_lengths[2]; - - const index_t Do = output_spatial_lengths[0]; - const index_t Ho = output_spatial_lengths[1]; - const index_t Wo = output_spatial_lengths[2]; - - const index_t Z = filter_spatial_lengths[0]; - const index_t Y = filter_spatial_lengths[1]; - const index_t X = filter_spatial_lengths[2]; - - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; - - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; - - const index_t ConvStrideD = conv_filter_strides[0]; - const index_t ConvStrideH = conv_filter_strides[1]; - const index_t ConvStrideW = conv_filter_strides[2]; - - const index_t ConvDilationD = conv_filter_dilations[0]; - const index_t ConvDilationH = conv_filter_dilations[1]; - const index_t ConvDilationW = conv_filter_dilations[2]; + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Di = a_g_n_c_wis_lengths[spatial_offset + I0]; + const index_t Hi = a_g_n_c_wis_lengths[spatial_offset + I1]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I2]; + const index_t Do = e_g_n_k_wos_lengths[spatial_offset + I0]; + const index_t Ho = e_g_n_k_wos_lengths[spatial_offset + I1]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I2]; + const index_t Z = b_g_k_c_xs_lengths[spatial_offset + I0]; + const index_t Y = b_g_k_c_xs_lengths[spatial_offset + I1]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset + I2]; + + const index_t InLeftPadD = input_left_pads[I0]; + const index_t InLeftPadH = input_left_pads[I1]; + const index_t InLeftPadW = input_left_pads[I2]; + const index_t InRightPadD = input_right_pads[I0]; + const index_t InRightPadH = input_right_pads[I1]; + const index_t InRightPadW = input_right_pads[I2]; + const index_t ConvStrideD = conv_filter_strides[I0]; + const index_t ConvStrideH = conv_filter_strides[I1]; + const index_t ConvStrideW = conv_filter_strides[I2]; + const index_t ConvDilationD = conv_filter_dilations[I0]; + const index_t ConvDilationH = conv_filter_dilations[I1]; + const index_t ConvDilationW = conv_filter_dilations[I2]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InDStride = a_g_n_c_wis_strides[spatial_offset]; + const auto InHStride = a_g_n_c_wis_strides[spatial_offset + I1]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I2]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2]; const index_t GemmKTotal = N * Do * Ho * Wo; - const index_t GemmM = K; - const index_t GemmN = C * Z * X * Y; - const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C)); + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride)); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } else { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); - const auto in_n_di_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride)); // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -672,27 +724,32 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } } // function end @@ -701,22 +758,22 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl static auto GetABCGridDesc() { return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( - 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); + {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); } template ::type = false> static auto GetABCGridDesc() { return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( - 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); + {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); } template ::type = false> static auto GetABCGridDesc() { - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, - 1, - 1, + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>({1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, @@ -784,17 +841,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl Argument(const InDataType* p_in_grid, WeiDataType* p_wei_grid, const OutDataType* p_out_grid, - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, @@ -810,13 +866,10 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl a_element_op_{out_element_op}, b_element_op_{wei_element_op}, c_element_op_{in_element_op}, - Conv_G_{G}, - Conv_N_{N}, - Conv_K_{K}, - Conv_C_{C}, - input_spatial_lengths_{input_spatial_lengths}, - filter_spatial_lengths_{filter_spatial_lengths}, - output_spatial_lengths_{output_spatial_lengths}, + Conv_G_{a_g_n_c_wis_lengths[I0]}, + Conv_K_{b_g_k_c_xs_lengths[I1]}, + Conv_C_{a_g_n_c_wis_lengths[I2]}, + filter_lengths_{b_g_k_c_xs_lengths}, conv_filter_strides_{conv_filter_strides}, conv_filter_dilations_{conv_filter_dilations}, input_left_pads_{input_left_pads}, @@ -825,12 +878,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl { const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -853,24 +906,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); // A/B/C Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = - N * K * - std::accumulate(begin(output_spatial_lengths), - end(output_spatial_lengths), - index_t{1}, - std::multiplies<>{}); - compute_ptr_offset_of_batch_.BatchStrideB_ = - N * C * - std::accumulate(begin(input_spatial_lengths), - end(input_spatial_lengths), - index_t{1}, - std::multiplies<>{}); - compute_ptr_offset_of_batch_.BatchStrideC_ = - K * C * - std::accumulate(begin(filter_spatial_lengths), - end(filter_spatial_lengths), - index_t{1}, - std::multiplies<>{}); + compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = b_g_k_c_xs_strides[I0]; } const ADataType* p_a_grid_; @@ -889,7 +927,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl Block2CTileMap block_2_ctile_map_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; // element-wise op OutElementwiseOperation a_element_op_; @@ -897,18 +935,15 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl InElementwiseOperation c_element_op_; // for checking IsSupportedArgument() - index_t Conv_G_; - index_t Conv_N_; - index_t Conv_K_; - index_t Conv_C_; - - std::array input_spatial_lengths_; - std::array filter_spatial_lengths_; - std::array output_spatial_lengths_; - std::array conv_filter_strides_; - std::array conv_filter_dilations_; - std::array input_left_pads_; - std::array input_right_pads_; + const index_t Conv_G_; + const index_t Conv_K_; + const index_t Conv_C_; + + std::array filter_lengths_; + const std::array& conv_filter_strides_; + const std::array& conv_filter_dilations_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; index_t k_batch_; }; @@ -964,7 +999,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl remove_reference_t, remove_reference_t, remove_reference_t, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop, has_double_loop>; @@ -1026,8 +1061,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl static bool IsSupportedArgument(const Argument& arg) { - // check device - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030")) + + // DL version only supports split_k equal to 1 + if(arg.k_batch_ != 1) + return false; + + if constexpr(!((NDimSpatial == 1 && (is_NWGK_GKXC_NWGC || is_GNWK_GKXC_GNWC)) || + (NDimSpatial == 2 && (is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) || + (NDimSpatial == 3 && (is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC)))) { return false; } @@ -1038,8 +1079,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl // check if it's 1x1, stride=1 pad = 0 conv for(int i = 0; i < NDimSpatial; i++) { - if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && - arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + if(!(arg.filter_lengths_[spatial_offset + i] == 1 && + arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 && + arg.input_right_pads_[i] == 0)) { return false; } @@ -1106,35 +1148,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl return IsSupportedArgument(*dynamic_cast(p_arg)); } - static auto MakeArgument(const InDataType* p_in_grid, - WeiDataType* p_wei_grid, - const OutDataType* p_out_grid, - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op, - ck::index_t split_k) + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) { return Argument{p_in_grid, p_wei_grid, p_out_grid, - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -1151,17 +1192,16 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl MakeArgumentPointer(const void* p_in_grid, void* p_wei_grid, const void* p_out_grid, - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, @@ -1170,13 +1210,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), static_cast(p_out_grid), - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -1197,7 +1236,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl auto str = std::stringstream(); // clang-format off - str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl" + str << "DeviceGroupedConvBwdWeight_Dl" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e8a721eb36b854643fe9a178075d8646dc384940 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -0,0 +1,877 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template ::type = false> +struct DeviceGroupedConvBwdWeight_Wmma_CShuffle + : public DeviceGroupedConvBwdWeight +{ + using DeviceOp = DeviceGroupedConvBwdWeight_Wmma_CShuffle; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + // 3d + static constexpr bool is_NDHWGK_GKZYXC_NDHWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNDHWK_GKZYXC_GNDHWC = + is_same_v && + is_same_v && + is_same_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto GemmK1Number = Number{}; + static constexpr index_t KPerBlock = K0PerBlock * GemmK1Number; + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Do, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Di, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C), + make_tuple(WiStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Z, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C), + make_tuple(KStride, CStride)); + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * Z * X * Y; + + const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock; + const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock; + + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock; + const index_t GemmKPad = GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Pad + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } + + template ::type = false> + static auto GetABCGridDesc() + { + const index_t dim = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + // DataType Family + ADataType, + BDataType, + AccDataType, + CDataType, + Tuple<>, + CDataType, + // InMemory Data Descriptor + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + Tuple<>, + CGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + K0PerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(Tuple<>{})); + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{})); + + using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap( + CGridDesc_M_N{}, I1 /* M01 */, I1 /* N01 */)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_G_{a_g_n_c_wis_lengths[0]}, + Conv_N_{a_g_n_c_wis_lengths[1]}, + Conv_K_{b_g_k_c_xs_lengths[1]}, + Conv_C_{a_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + constexpr index_t spatial_offset = 3; + std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, + end(a_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset, + end(b_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset, + end(e_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + a_g_n_c_wis_strides, + b_g_k_c_xs_strides, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap( + c_grid_desc_m_n_, I1 /* M01 */, I1 /* N01 */); + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), + index_t{1}, + std::multiplies<>{}); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + Block2CTileMap block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void Print(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(arg); + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle has invalid " + "setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_; + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + CDataType, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + using EmptyTuple = Tuple<>; + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + EmptyTuple{}, // Ds + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.Conv_G_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock{}, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + if(has_main_k0_block_loop) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // check device + if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" || + get_device_name() == "gfx1102") + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + // TODO: Add support for split_k > 1 + if(arg.k_batch_ != 1) + { + return false; + } + + if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC)) + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's a 1x1 convolution with stride=1 and no padding + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeight_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << CShuffleBlockTransferScalarPerVector_NPerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp similarity index 62% rename from include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index c921c9f1b400f76627376d4865498aa27e7b49c3..cbb8643627b9c95891df5e5af7f01266055d4fec 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -21,34 +22,9 @@ namespace ck { namespace tensor_operation { namespace device { -namespace { - -struct ComputePtrOffsetOfStridedBatch -{ - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - index_t BatchStrideC_; -}; - -} // namespace - template (compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); - __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; + __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; GridwiseGemm::template Run(p_a_grid + a_batch_offset, p_b_grid + b_batch_offset, @@ -126,6 +102,9 @@ __global__ void // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] template -struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle - : public DeviceGroupedConvBwdWeight< - NDimSpatial, - ck::tuple_element_t>, - ck::tuple_element_t>, - ck::tuple_element_t>, - InDataType, - WeiDataType, - OutDataType, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + typename ComputeTypeA = InDataType, + typename ComputeTypeB = ComputeTypeA> +struct DeviceGroupedConvBwdWeight_Xdl_CShuffle + : public DeviceGroupedConvBwdWeight { - using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle; + using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle; using ADataType = OutDataType; using BDataType = InDataType; @@ -196,6 +169,30 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle // TODO make A/B datatype different using ABDataType = InDataType; + // 1d + static constexpr bool is_GNWK_GKXC_GNWC = + is_same_v && + is_same_v && + is_same_v; + // 2d + static constexpr bool is_NHWGK_GKYXC_NHWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNHWK_GKYXC_GNHWC = + is_same_v && + is_same_v && + is_same_v; + // 3d + static constexpr bool is_NDHWGK_GKZYXC_NDHWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNDHWK_GKZYXC_GNDHWC = + is_same_v && + is_same_v && + is_same_v; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -220,19 +217,132 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; static constexpr auto BBlockLdsN1Padding = 4; + template ::type = false> + constexpr static auto + make_out_grid_desc(const ck::index_t N, + const ck::index_t Ho, + const ck::index_t Wo, + const ck::index_t K, + const std::array& output_strides) + { + const index_t WoStride = output_strides[4]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const ck::index_t N, + const ck::index_t Hi, + const ck::index_t Wi, + const ck::index_t C, + const std::array& input_strides) + { + const index_t NStride = input_strides[1]; + const index_t HiStride = input_strides[3]; + const index_t WiStride = input_strides[4]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C), + make_tuple(WiStride, CStride)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), + make_tuple(NStride, HiStride, WiStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const ck::index_t K, + const ck::index_t Y, + const ck::index_t X, + const ck::index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride)); + } + + template ::type = false> + constexpr static auto + make_out_grid_desc(const ck::index_t N, + const ck::index_t Do, + const ck::index_t Ho, + const ck::index_t Wo, + const ck::index_t K, + const std::array& output_strides) + { + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const ck::index_t N, + const ck::index_t Di, + const ck::index_t Hi, + const ck::index_t Wi, + const ck::index_t C, + const std::array& input_strides) + { + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C), + make_tuple(WiStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const ck::index_t K, + const ck::index_t Z, + const ck::index_t Y, + const ck::index_t X, + const ck::index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C), + make_tuple(KStride, CStride)); + } + template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - ck::index_t batch_k) + const ck::index_t N, + const ck::index_t K, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& /* input_strides */, + const std::array& /* weights_strides */, + const std::array& /* output_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) { using namespace ck; @@ -248,6 +358,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle const index_t GemmM = K; const index_t GemmN = C * X; + const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock; + const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock; + const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * @@ -282,14 +395,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( in_gemmktotal_gemmn_grid_desc, make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), + make_pass_through_transform(GemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_pass_through_transform(GemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -366,25 +479,56 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); } } template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - ck::index_t batch_k) + const ck::index_t N, + const ck::index_t K, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) { using namespace ck; @@ -413,21 +557,25 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle const index_t GemmM = K; const index_t GemmN = C * X * Y; + const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock; + const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock; + const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Y, X, C, weights_strides); + if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, + out_grid_desc, make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), make_pass_through_transform(GemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -441,41 +589,29 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C)); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, + in_grid_desc, make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), + make_pass_through_transform(GemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_pass_through_transform(GemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_grid_desc); } else { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - const auto in_n_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - // A: output tensor const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, + out_grid_desc, make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), make_pass_through_transform(GemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -490,7 +626,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle // B: input tensor const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, + in_grid_desc, make_tuple(make_pass_through_transform(N), make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Wi, InLeftPadW, InRightPadW), @@ -529,29 +665,56 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); } } template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - ck::index_t batch_k) + const ck::index_t N, + const ck::index_t K, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) { using namespace ck; @@ -587,21 +750,25 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle const index_t GemmM = K; const index_t GemmN = C * Z * X * Y; + const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock; + const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock; + const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); + if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, + out_grid_desc, make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), make_pass_through_transform(GemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -615,41 +782,29 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C)); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, + in_grid_desc, make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), + make_pass_through_transform(GemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( in_gemmkpad_gemmn_grid_desc, make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + make_pass_through_transform(GemmN)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_grid_desc); } else { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); - const auto in_n_di_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); - // A: output tensor const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, + out_grid_desc, make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), make_pass_through_transform(GemmM)), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -664,7 +819,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle // B: input tensor const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_grid_desc, + in_grid_desc, make_tuple(make_pass_through_transform(N), make_pad_transform(Di, InLeftPadD, InRightPadD), make_pad_transform(Hi, InLeftPadH, InRightPadH), @@ -712,44 +867,110 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); - - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); } } // function end template ::type = false> static auto GetABCGridDesc() { - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( - 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1}; + const std::array strides{1, 1, 1, 1}; + const std::array params{1}; + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); } template ::type = false> static auto GetABCGridDesc() { - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( - 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); } template ::type = false> static auto GetABCGridDesc() { - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, - 1, - 1, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - 1); + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); } // type convert descs @@ -804,7 +1025,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, - ADataType, // TODO: distinguish A/B datatype + ADataType, + BDataType, AccDataType, CDataType, InMemoryDataOperationEnum::AtomicAdd, @@ -849,7 +1071,11 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle CBlockTransferScalarPerVector_NWaveNPerXdl, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, true, - true>; + true, + 1, + PipelineVersion::v1, + ComputeTypeA, + ComputeTypeB>; // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = @@ -863,19 +1089,18 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle Argument(const InDataType* p_in_grid, WeiDataType* p_wei_grid, const OutDataType* p_out_grid, - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - ck::index_t M01, - ck::index_t N01, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t M01, + const ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, @@ -894,25 +1119,40 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle a_element_op_{out_element_op}, b_element_op_{in_element_op}, c_element_op_{wei_element_op}, - Conv_G_{G}, - Conv_N_{N}, - Conv_K_{K}, - Conv_C_{C}, - output_spatial_lengths_{output_spatial_lengths}, - filter_spatial_lengths_{filter_spatial_lengths}, + Conv_G_{a_g_n_c_wis_lengths[0]}, + Conv_N_{a_g_n_c_wis_lengths[1]}, + Conv_K_{b_g_k_c_xs_lengths[1]}, + Conv_C_{a_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads}, k_batch_{split_k} { + constexpr index_t spatial_offset = 3; + std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, + end(a_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset, + end(b_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset, + end(e_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + a_g_n_c_wis_strides, + b_g_k_c_xs_strides, + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -927,22 +1167,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); // A/B/C Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = - N * K * - std::accumulate(begin(output_spatial_lengths), - end(output_spatial_lengths), - index_t{1}, - std::multiplies<>{}); - compute_ptr_offset_of_batch_.BatchStrideB_ = - N * C * - std::accumulate(begin(input_spatial_lengths), - end(input_spatial_lengths), - index_t{1}, - std::multiplies<>{}); + compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0]; compute_ptr_offset_of_batch_.BatchStrideC_ = - K * C * - std::accumulate(begin(filter_spatial_lengths), - end(filter_spatial_lengths), + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), index_t{1}, std::multiplies<>{}); @@ -967,26 +1197,27 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle Block2CTileMap block_2_ctile_map_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; index_t M01_; index_t N01_; - InElementwiseOperation a_element_op_; - OutElementwiseOperation b_element_op_; + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; WeiElementwiseOperation c_element_op_; // for checking IsSupportedArgument() - index_t Conv_G_; - index_t Conv_N_; - index_t Conv_K_; - index_t Conv_C_; - std::array output_spatial_lengths_; + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; std::array filter_spatial_lengths_; - std::array conv_filter_strides_; - std::array input_left_pads_; - std::array input_right_pads_; - index_t k_batch_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; }; // Invoker @@ -1035,7 +1266,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype + ADataType, + BDataType, CDataType, OutElementwiseOperation, InElementwiseOperation, @@ -1044,7 +1276,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle remove_reference_t, remove_reference_t, remove_reference_t, - ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, has_main_loop>; return launch_and_time_kernel(stream_config, @@ -1091,6 +1323,36 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(NDimSpatial == 1) + { + if constexpr(!is_GNWK_GKXC_GNWC) + { + return false; + } + } + else if constexpr(NDimSpatial == 2) + { + if constexpr(!(is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) + { + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC)) + { + return false; + } + } + else + { + return false; + } + if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { @@ -1131,35 +1393,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle return IsSupportedArgument(*dynamic_cast(p_arg)); } - static auto MakeArgument(const InDataType* p_in_grid, - WeiDataType* p_wei_grid, - const OutDataType* p_out_grid, - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op, - ck::index_t split_k) + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) { return Argument{p_in_grid, p_wei_grid, p_out_grid, - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -1178,32 +1439,30 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle MakeArgumentPointer(const void* p_in_grid, void* p_wei_grid, const void* p_out_grid, - ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::array input_spatial_lengths, - std::array filter_spatial_lengths, - std::array output_spatial_lengths, - std::array conv_filter_strides, - std::array conv_filter_dilations, - std::array input_left_pads, - std::array input_right_pads, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, - ck::index_t split_k) override + const ck::index_t split_k) override { return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), static_cast(p_out_grid), - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -1226,7 +1485,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle auto str = std::stringstream(); // clang-format off - str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle" + str << "DeviceGroupedConvBwdWeight_Xdl_CShuffle" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp similarity index 96% rename from include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 66b47abb624328f31984e80c5d32811d54bbc4c2..8cfdd04e55654f2d71d40304d43a200fee821a26 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -19,6 +19,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" @@ -29,51 +30,6 @@ namespace device { namespace { -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - /* * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. * @@ -134,8 +90,9 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ - defined(__gfx90a__) || defined(__gfx908__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \ + defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -380,8 +337,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK } // desc for problem definition - using AGridDesc_AK0_M_AK1 = remove_cvref_t({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using AGridDesc_AK0_M_AK1 = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; using BGridDesc_BK0_N_BK1 = remove_cvref_t({}, {}))>; using DsGridDesc_M_N = remove_cvref_t; @@ -710,7 +667,10 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK // check device if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx908")) + ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx908" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" || + ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102" || + ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942")) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp similarity index 98% rename from include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index c77772faa4c060dbeceba009e9a742be4d592a6a..f18fbcfe4b51031009f6329baac2bc1ac209e1fa 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -106,7 +106,8 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ + defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -319,8 +320,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using AGridDesc_AK0_M_AK1 = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; using BGridDesc_BK0_N_BK1 = remove_cvref_t({}, {}))>; using CGridDesc_M_N = remove_cvref_t({}, {}))>; @@ -600,7 +601,9 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using AGridDesc_M_K = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; using BGridDesc_N_K = remove_cvref_t({}, {}))>; using EGridDesc_M_N = remove_cvref_t({}, {}))>; using RGridDesc_M = remove_cvref_t({}, {}))>; @@ -507,10 +507,12 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle RThreadTransferDstScalarPerVector_MPerBlock, LoopSched>; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; @@ -811,7 +813,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle return false; } } - else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940") + else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" || + get_device_name() == "gfx941" || get_device_name() == "gfx942") { if constexpr(!(is_same_v || is_same_v || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 9d4b68c0b6429aacd16106d1088db5e1c1f74d4f..3e303f5c5ea3989f32871ae7586f68b0300e234f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -19,6 +19,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" @@ -27,55 +28,6 @@ namespace ck { namespace tensor_operation { namespace device { -namespace { - -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - -} // namespace - // // @brief Device Convolution operation. // @@ -245,8 +197,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle } // desc for problem definition - using AGridDesc_M_K = remove_cvref_t({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using AGridDesc_M_K = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; using BGridDesc_N_K = remove_cvref_t({}, {}))>; using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = remove_cvref_t({}, {}))>; @@ -519,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle auto launch_kernel = [&](auto has_main_k_block_loop) { constexpr bool has_main_loop = has_main_k_block_loop.value; - const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle< + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< GridwiseOp, ADataType, BDataType, @@ -599,7 +551,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // check if it's 1x1, stride=1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t ConvStride = arg.conv_filter_strides_[i]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; @@ -616,7 +568,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // check if it's 1x1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp index 02458bf02a39870118b846bf99fc774d5d73ee4c..f4b8d66ecf4041a3f544a63dc2d70ec979fb65bc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp @@ -19,6 +19,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/io.hpp" @@ -29,51 +30,6 @@ namespace device { namespace { -template -struct ComputePtrOffsetOfStridedBatch -{ - ComputePtrOffsetOfStridedBatch() = default; - - ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB, - Array BatchStrideDs, - index_t BatchStrideE) - : BatchStrideA_(BatchStrideA), - BatchStrideB_(BatchStrideB), - BatchStrideDs_(BatchStrideDs), - BatchStrideE_(BatchStrideE) - { - } - - __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB_); - } - - __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const - { - Array ds_offset; - static_for<0, NumDTensor, 1>{}( - [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); - return ds_offset; - } - - __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE_); - } - - index_t BatchStrideA_; - index_t BatchStrideB_; - Array BatchStrideDs_; - index_t BatchStrideE_; -}; - /* * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. * @@ -136,7 +92,7 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -255,7 +211,8 @@ template + typename ComputeDataType = ADataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle : public DeviceGroupedConvFwdMultipleD + CDEElementwiseOperation, + ComputeDataType> { using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle; @@ -361,8 +319,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle } // desc for problem definition - using AGridDesc_M_K = remove_cvref_t({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using AGridDesc_M_K = remove_cvref_t( + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; using BGridDesc_N_K = remove_cvref_t({}, {}))>; using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = remove_cvref_t({}, {}))>; @@ -370,6 +328,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, @@ -412,14 +372,18 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle LoopSched>; // desc for blockwise copy - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; // block-to-e-tile map using Block2ETileMap = @@ -685,7 +649,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle return false; } } - else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940") + else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940" || + get_device_name() == "gfx941" || get_device_name() == "gfx942") { if constexpr(!(is_same_v || is_same_v || is_same_v || is_same_v)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e19a00299e5affa1618a1e4d0fbc44a4f2038b41 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct ComputePtrOffsetOfStridedBatch +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + Array BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + return ds_offset; + } + + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + // alias for kernels without multiple D + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + index_t BatchStrideA_; + index_t BatchStrideB_; + Array BatchStrideDs_; + index_t BatchStrideE_; + index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index d424f2992076cc32704631c3badbefa3c88dec2c..0190b3cee6b44af7656c34be866c9ab0855f7b0d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -1,6 +1,6 @@ #pragma once // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -39,8 +39,9 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx1030__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx1030__) || defined(__gfx1100__) || defined(__gfx1101__) || \ + defined(__gfx1102__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); @@ -596,10 +597,12 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, AccDataType, CShuffleDataType, DsDataType, @@ -272,14 +276,18 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm; - using AGridDesc_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BK0_N_BK1 = remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; struct GroupedGemmBlock2ETileMap { @@ -548,11 +556,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(arg.gemm_desc_kernel_arg_.size()) + arg.skipped_group_count_) != arg.group_count_) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..56132f7a0f8a9605549857787900463f8ab0a6da --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -0,0 +1,839 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + uint32_t* barrier_count, + const index_t barrier_size_grp, + const index_t group_count, + const index_t grid_size_grp, + const index_t KBatch, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + const index_t M = gemm_desc_ptr[group_id].M; + const index_t N = gemm_desc_ptr[group_id].N; + const index_t K = gemm_desc_ptr[group_id].K; + + if(M * N * K == 0) + return; + + const auto StrideA = gemm_desc_ptr[group_id].StrideA; + const auto StrideB = gemm_desc_ptr[group_id].StrideB; + const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; + const auto StrideE = gemm_desc_ptr[group_id].StrideE; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumDTensor = DsDataType::Size(); + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + + DsGridPointer p_ds_grid_; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + // D pointer + p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + const index_t mn_blocks = local_grid_size / KBatch; + + while(id_local < local_grid_size) + { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + + auto barrier_count_finished = + barrier_count + group_id * barrier_size_grp + id_local % mn_blocks; + + GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + barrier_count_finished, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + + id_off += grid_size_grp; + id_local += grid_size_grp; + } +#else + ignore = gemm_descs_const; + ignore = barrier_count; + ignore = barrier_size_grp; + ignore = group_count; + ignore = grid_size_grp; + ignore = KBatch; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif +} + +template +struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK +{ + using DeviceOp = DeviceGroupedGemm_Xdl_Fixed_NK; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle< + ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + NumPrefetch, // NumGemmKPrefetchStage + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + template + struct OffsettedBlockToCTileMapMLoops + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; + }; + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); + + block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; + + using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + + struct GemmBiasTransKernelArg + { + // pointers + const void* a_ptr_; + const void* b_ptr_; + std::array ds_ptr_; + void* e_ptr_; + + index_t M_, N_, K_; + index_t StrideA_, StrideB_; + std::array StrideDs_; + index_t StrideE_; + }; + + // Argument + struct Argument : public BaseArgument + { + + void UpdateKBatch(index_t k_batch) + { + k_batch_ = k_batch; + + if(k_batch_ < 1) + { + + throw std::runtime_error("wrong! k_batch must be > 0"); + } + + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + + const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE_; + const index_t N = gemm_desc_kernel_arg_[0].N_; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + grid_size_ = grid_size_grp_ * group_count_; + } + + Argument(std::vector&, + std::vector&, + std::vector>&, + std::vector&, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} + { + grid_size_ = 0; + + k_batch_ = 1; + + grouped_gemm_kernel_args_dev = nullptr; + + group_count_ = ck::type_convert(gemm_descs.size()); + + gemm_desc_kernel_arg_.reserve(group_count_); + + index_t group_id = 0; + + sum_of_m = gemm_descs[0].M_; + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + const index_t N = gemm_descs[0].N_; + const index_t K = gemm_descs[0].K_; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_) + { + throw std::runtime_error("wrong! M/N/K is not identical"); + } + + a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + const index_t StrideA = gemm_descs[i].stride_A_; + const index_t StrideB = gemm_descs[i].stride_B_; + const index_t StrideE = gemm_descs[i].stride_C_; + + // pointer + std::array p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; }); + + std::array StrideDs; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + // using DLayout = remove_cvref_t>; + + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + StrideDs[j] = gemm_descs[i].stride_Ds_[j]; + }); + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + // block-to-e-tile map + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + if(group_id * grid_size_grp_ != grid_size_) + { + throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); + } + + grid_size_ += grid_size_grp_; + + // check block-to-E-tile + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! block_2_etile_map validation failed"); + } + + if(!GridwiseGemm:: + template CheckValidity( + AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{ + nullptr, + nullptr, + p_ds_grid, + nullptr, + AverM, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + }); + + group_id++; + } + + const auto e_grid_desc_sum_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; + + barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); + } + + // private: + index_t group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + const void* grouped_gemm_kernel_args_dev; + + index_t grid_size_; + index_t grid_size_grp_; + index_t barrier_size_grp_; + index_t sum_of_m; + + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool has_main_k_block_loop = true; + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + const auto KPad = + GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, arg.k_batch_); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + if(arg.grouped_gemm_kernel_args_dev == nullptr) + { + throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr"); + } + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) { + const auto kernel = + kernel_grouped_gemm_xdl_fixed_nk, + GemmSpec, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + reinterpret_cast(arg.p_workspace_), + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + }; + + constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd; + constexpr auto Set = InMemoryDataOperationEnum::Set; + + if(arg.k_batch_ > 1) + { + if(has_main_k_block_loop) + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + } + else + { + if(has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + { + return false; + } + + bool supported = true; + + // If we use padding we do not support vector loads for dimensions not divisible by vector + // load size. + if constexpr(GemmSpec != GemmSpecialization::Default) + { + // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout, + // thus we have to adapt it to the {M,K} or {N,K} layout. + const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; + const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0; + + for(index_t i = 0; i < arg.group_count_; ++i) + { + const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + } + } + + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_Xdl_Fixed_NK" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args) + { + arg.grouped_gemm_kernel_args_dev = kernel_args; + } + + // polymorphic + void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), kernel_args); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * arg.barrier_size_grp_ * sizeof(uint32_t); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * sizeof(GroupedGemmKernelArgument); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override + { + auto p_arg_ = dynamic_cast(p_arg); + p_arg_->p_workspace_ = p_workspace; + + hip_check_error(hipMemset(p_workspace, 0, GetWorkSpaceSize(p_arg))); + } + + static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } + + // polymorphic + void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override + { + return SetKBatch(*dynamic_cast(p_arg), k_batch); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..41445aeaf076408adc7159359884803f0965e313 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -0,0 +1,632 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + __shared__ uint8_t p_shared[shared_size]; + + const index_t block_id = get_block_1d_id(); + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].block_start_ && + block_id < gemm_desc_ptr[group_id].block_end_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + gemm_desc_ptr[group_id].karg_, + static_cast(p_shared), + gemm_desc_ptr[group_id].block_2_ctile_map_); +#else + ignore = gemm_descs_const; + ignore = group_count; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template > && + is_same_v>, + bool> = false> +struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static_assert(KPerBlock % AK1 == 0); + static constexpr index_t K0PerBlock = KPerBlock / AK1; + + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, + BDataType, + AccDataType, + EDataType, + ALayout, + BLayout, + ELayout, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + NumGemmKPrefetchStage, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + AK1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferScalarPerVector_NPerBlock, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer>; + + using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using Block2ETileMapKSplit = + BlockToCTileMap_KSplit_M00_N0_M01Adapt; + // Block2CTileMap configuration parameter. + static constexpr index_t B2E_M01 = 8; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; + using KernelArgument = typename GridwiseGemm::Argument; + + struct GemmTransKernelArg + { + KernelArgument karg_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t block_start_, block_end_; + + GemmTransKernelArg() = default; + GemmTransKernelArg(KernelArgument&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) + : karg_{karg}, + block_2_ctile_map_{b2c_map}, + block_start_{block_start}, + block_end_{block_end} + { + } + }; + + static constexpr index_t DefaultKBatch = 1; + + // Argument + struct Argument : public BaseArgument + { + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector& p_Es, + std::vector& gemm_descs) + : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch) + { + // TODO: use occupancy api to calculate appropriate batch size. + } + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector& p_Es, + std::vector& gemm_descs, + index_t kbatch) + : K_BATCH{kbatch} + { + grid_size_ = 0; + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("wrong! group_count_ != p_As/b/c.size"); + } + + gemm_kernel_args_.reserve(group_count_); + + skipped_group_count_ = 0; + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + if(M == 0) + { + skipped_group_count_++; + continue; + } + + const index_t stride_a = gemm_descs[i].stride_A_; + const index_t stride_b = gemm_descs[i].stride_B_; + const index_t stride_c = gemm_descs[i].stride_C_; + + const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH); + const index_t k0 = GridwiseGemm::CalculateK0(K, K_BATCH); + + const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + auto karg = KernelArgument{type_convert(p_As[i]), + type_convert(p_Bs[i]), + type_convert(p_Es[i]), + M, + N, + K, + stride_a, + stride_b, + stride_c, + m_padded, + n_padded, + k_padded, + k0, + K_BATCH}; + + gemm_kernel_args_.emplace_back( + std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + } + } + + /** + * @brief Recalculate group grid size for all gemms and update B2C maps. + * + * @param[in] kbatch The new splitK parameter value. + */ + void UpdateKBatch(index_t kbatch) + { + K_BATCH = kbatch; + grid_size_ = 0; + + for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i) + { + + auto& karg = gemm_kernel_args_[i].karg_; + + const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); + const index_t k0 = GridwiseGemm::CalculateK0(karg.K, K_BATCH); + + const auto c_grid_desc_m_n = + GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + karg.KPadded = k_padded; + karg.K0 = k0; + karg.k_batch = K_BATCH; + gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; + gemm_kernel_args_[i].block_start_ = block_start; + gemm_kernel_args_[i].block_end_ = block_end; + } + } + + // private: + index_t K_BATCH; + index_t group_count_; + index_t skipped_group_count_; + + std::vector gemm_kernel_args_; + index_t grid_size_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t K0 = arg.gemm_kernel_args_[0].karg_.K0; + bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1; + bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& karg = arg.gemm_kernel_args_[i].karg_; + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + auto kbatch = karg.k_batch; + + if(!GridwiseGemm::CheckValidity(karg)) + { + std::ostringstream err; + err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + K0 = karg.K0; + bool not_all_have_main_k0_block_loop_same = + all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1); + + if(not_all_have_main_k0_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! " + << "group [" << i << "], kbatch: " << kbatch + << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch + << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + hip_check_error( + hipMemcpyWithStream(arg.p_workspace_, + arg.gemm_kernel_args_.data(), + arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(all_have_kbatch_gt_one) + { + for(const auto& trans_arg : arg.gemm_kernel_args_) + { + const auto& karg = trans_arg.karg_; + hip_check_error(hipMemsetAsync(karg.p_c_grid, + 0, + karg.M * karg.N * sizeof(EDataType), + stream_config.stream_id_)); + } + } + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_kernel_args_.size()); + }; + + if(all_have_main_k0_block_loop) + { + if(all_have_kbatch_gt_one) + { + const auto kernel = + kernel_grouped_gemm_xdl_splitk; + + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_gemm_xdl_splitk; + + Run(kernel); + } + } + else + { + if(all_have_kbatch_gt_one) + { + const auto kernel = + kernel_grouped_gemm_xdl_splitk; + + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_gemm_xdl_splitk; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((ck::type_convert(arg.gemm_kernel_args_.size()) + + arg.skipped_group_count_) != arg.group_count_) + { +#if DEBUG_LOG + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; +#endif // DEBUG_LOG + return false; + } + + bool supported = true; + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& a = arg.gemm_kernel_args_[i].karg_; + bool group_arg_valid = GridwiseGemm::CheckValidity(a); + if(not group_arg_valid) + { +#if DEBUG_LOG + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" << std::endl; + a.Print(); +#endif // DEBUG_LOG + } + supported = supported && group_arg_valid; + } + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>&, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation) + { + return Argument{p_As, p_Bs, p_Es, gemm_descs}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>&, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation) override + { + return std::make_unique(p_As, p_Bs, p_Es, gemm_descs); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_XdlSplitK" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->gemm_kernel_args_.size() * + sizeof(GemmTransKernelArg); + } + + static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } + + // polymorphic + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + return SetKBatchSize(*dynamic_cast(p_arg), kbatch); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c83ffdcd2698489df4504c60394043d96aa02da6 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp @@ -0,0 +1,407 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Image to column: +// input : input image [G, N, Di, Hi, Wi, C] +// output : gemm form [G * N * Do * Ho * Wo, Z * Y * X * C] +// input : input image [N, Di, Hi, Wi, G, C] +// output : gemm form [N * Do * Ho * Wo * G, Z * Y * X * C] +template = 1 && NDimSpatial <= 3, bool>::type = false> +struct DeviceImageToColumnImpl + : public DeviceConvTensorRearrange +{ + static constexpr bool is_NSpatialGC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + static constexpr bool is_GNSpatialC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + + static constexpr auto matrix_padder = + MatrixPadder{ + MPerBlock, 0 /* NPerBlock*/, KPerBlock}; + + // Use MakeADescriptor_M_K from grouped convolution forward + static auto + MakeInputDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + std::array a_g_n_c_wis_lengths{1}; + std::array b_g_k_c_xs_lengths{1}; + std::array c_g_n_k_wos_lengths{1}; + + auto copy = [](const auto& x, auto& y, index_t dst_offset) { + std::copy(x.begin(), x.end(), y.begin() + dst_offset); + }; + + constexpr index_t spatial_offset = 3; + + copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset); + copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset); + copy(output_spatial_lengths, c_g_n_k_wos_lengths, spatial_offset); + + // fill only significant values (C and N) + a_g_n_c_wis_lengths[I1] = N; + a_g_n_c_wis_lengths[I2] = C; + b_g_k_c_xs_lengths[I2] = C; + c_g_n_k_wos_lengths[I1] = N; + + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K( + a_g_n_c_wis_lengths, + image_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + {}, // not needed for A Descriptor + c_g_n_k_wos_lengths, + {}, // not needed for A Descriptor + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + return in_gemmm_gemmk_desc; + } + + static auto + MakeOutDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& gemm_g_m_k_strides) + { + const index_t NDoHoWo = + N * ck::accumulate_n( + output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t CZYX = + C * ck::accumulate_n( + filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + + const auto desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(NDoHoWo, CZYX), make_tuple(gemm_g_m_k_strides[I1], gemm_g_m_k_strides[I2])); + return matrix_padder.PadADescriptor_M_K(desc_mraw_kraw); + } + + using InputGridDesc = + remove_cvref_t; + using OutputGridDesc = remove_cvref_t; + + using Block2ETileMap = remove_cvref_t< + decltype(BlockToCTileMap_M00_N0_M01Adapt( + OutputGridDesc{}))>; + + using GridwiseTensorRearrangeKernel = + GridwiseTensorRearrange>; + + struct Argument : public BaseArgument + { + Argument(const void* p_in, // input image + void* p_out, // gemm form + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + : G_(G), + C_(C), + X_(filter_spatial_lengths[NDimSpatial - I1]), + p_in_{static_cast(p_in)}, + p_out_{static_cast(p_out)}, + image_g_n_c_wis_strides_{image_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + + in_grid_desc_m_k_ = MakeInputDescriptor_M_K(N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + out_grid_desc_m_k_ = MakeOutDescriptor_M_K( + N, C, filter_spatial_lengths, output_spatial_lengths, gemm_g_m_k_strides); + + compute_ptr_offset_of_batch_.BatchStrideA_ = image_g_n_c_wis_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = gemm_g_m_k_strides[I0]; + } + + void Print() const + { + std::cout << in_grid_desc_m_k_ << std::endl; + std::cout << out_grid_desc_m_k_ << std::endl; + } + + const ck::index_t G_; + const ck::index_t C_; + const ck::index_t X_; + + const InputDataType* p_in_; + OutputDataType* p_out_; + + const std::array& image_g_n_c_wis_strides_; + const std::array& conv_filter_strides_; + const std::array& conv_filter_dilations_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + + InputGridDesc in_grid_desc_m_k_; + OutputGridDesc out_grid_desc_m_k_; + + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + const auto block_2_tile_map = + BlockToCTileMap_M00_N0_M01Adapt( + arg.out_grid_desc_m_k_); + const index_t grid_size = + block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_) * arg.G_; + const auto kernel = kernel_tensor_rearrange, + GridwiseTensorRearrangeKernel>; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.in_grid_desc_m_k_, + arg.p_in_, + arg.out_grid_desc_m_k_, + arg.p_out_, + arg.G_, + block_2_tile_map, + arg.compute_ptr_offset_of_batch_); + return elapsed_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const Argument& arg) + { + if constexpr(!(is_NSpatialGC || is_GNSpatialC)) + { + return false; + } + + const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1]; + const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1]; + const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1]; + const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1]; + bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_; + bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1; + + // check vector acces with c not packed + if(!is_c_packed && ScalarPerVector != 1) + return false; + // check vector access of filter window row (only C if C is not packed) + if(!is_w_packed && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of filter window row (X * C) + if(arg.X_ * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of pads (w_pad_left/w_pad_right * C) + if(w_pad_left * arg.C_ % ScalarPerVector != 0 || + w_pad_right * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with stride and pad + if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with dilation + if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + + return GridwiseTensorRearrangeKernel::CheckValidity(arg.in_grid_desc_m_k_, + arg.out_grid_desc_m_k_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_in, // input image + void* p_out, // gemm form + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + return Argument{static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in, // input image + void* p_out, // gemm form + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) override + { + return std::make_unique(static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceImageToColumn" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << KPerBlock << ", " + << ScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e98a85defe95aabed5c977250048cae117b701d6 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp @@ -0,0 +1,325 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// output[indices] = input +template +struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd +{ + using DInDataType_AutomicAddPreCast = + conditional_t || is_same_v, + DInDataType, + float>; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvert; + + static constexpr auto I0 = Number<0>{}; + + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t loop_step) + { + const auto m = desc_m.GetLength(I0); + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(index_t length, index_t loop_step) + { + const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length)); + return PadDescriptor_M_1d(desc_m, loop_step); + } + + using InOutGrid1dDesc = decltype(MakeDescriptor_M(1, 1)); + + using GridwisePutElementSet = GridwisePutElement_1D; + + using GridwisePutElementAtomicAdd = GridwisePutElement_1D; + + using GridwiseCasting = GridwiseElementwise_1D, + Tuple, + Tuple, + Tuple, + UnaryConvert, + InOutVectorSize, + Sequence, + Sequence>; + + struct Argument : public BaseArgument + { + Argument(const DOutDataType* p_dout, + const IndexDataType* p_indices, + DInDataType* p_din, + index_t dout_length, + index_t din_length, + const std::vector& window_lengths, + const std::vector& window_strides, + const std::vector& window_dilations) + : p_dout_{p_dout}, + p_indices_{p_indices}, + p_din_{p_din}, + dout_length_raw_{dout_length}, + din_length_raw_{din_length}, + blockSize_{256}, + windowOverlap_{false} + { + for(size_t i = 0; i < window_lengths.size(); ++i) + { + auto eff = (window_lengths.at(i) - 1) * window_dilations.at(i) + 1; + windowOverlap_ |= eff > window_strides.at(i); + } + } + + const DOutDataType* p_dout_; + const IndexDataType* p_indices_; + DInDataType* p_din_; + index_t dout_length_raw_; + index_t din_length_raw_; + index_t blockSize_; + bool windowOverlap_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + index_t loop_step = gridSize * arg.blockSize_ * InOutVectorSize; + InOutGrid1dDesc din_grid_desc = MakeDescriptor_M(arg.din_length_raw_, loop_step); + InOutGrid1dDesc dout_grid_desc = MakeDescriptor_M(arg.dout_length_raw_, loop_step); + + if constexpr(is_same_v || is_same_v) + { + hip_check_error(hipMemsetAsync(arg.p_din_, + 0, + arg.din_length_raw_ * sizeof(DInDataType), + stream_config.stream_id_)); + + if(arg.windowOverlap_) + { + const auto put_kernel = kernel_put_element_1d; + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + else + { + const auto put_kernel = kernel_put_element_1d; + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + } + else + { + if(arg.windowOverlap_) + { + if(arg.p_workspace_ == nullptr) + throw std::runtime_error("wrong! WorkSpace pointer has not been set"); + + hip_check_error( + hipMemsetAsync(arg.p_workspace_, + 0, + arg.din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast), + stream_config.stream_id_)); + + const auto put_kernel = kernel_put_element_1d; + + const auto cast_kernel = + kernel_elementwise_1d, + Tuple, + Tuple, + Tuple, + UnaryConvert>; + + float elapsed_time = launch_and_time_kernel( + stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + static_cast(arg.p_workspace_), + PassThrough{}); + + elapsed_time += launch_and_time_kernel( + stream_config, + cast_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + ck::make_tuple(din_grid_desc), + ck::make_tuple(din_grid_desc), + static_cast(arg.p_workspace_), + arg.p_din_, + UnaryConvert{}); + + return elapsed_time; + } + else + { + hip_check_error(hipMemsetAsync(arg.p_din_, + 0, + arg.din_length_raw_ * sizeof(DInDataType), + stream_config.stream_id_)); + + const auto put_kernel = kernel_put_element_1d; + + hip_check_error(hipMemsetAsync(arg.p_din_, + 0, + arg.din_length_raw_ * sizeof(DInDataType), + stream_config.stream_id_)); + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + bool needCast = pArg_->windowOverlap_ && + !(is_same_v || is_same_v); + + if(!needCast) + return 0; + else + return pArg_->din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast); + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + if(pArg->din_length_raw_ % InOutVectorSize != 0 || + pArg->dout_length_raw_ % InOutVectorSize != 0) + { + return false; + } + return true; + } + + std::unique_ptr + MakeArgumentPointer(const void* p_dout, + const void* p_indices, + void* p_din, + index_t dout_length, + index_t din_length, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations) override + { + // Assume p_dout, p_indices, p_din are packed memory space, dout_length and din_length are + // physical size of the packed tensor + return std::make_unique(static_cast(p_dout), + static_cast(p_indices), + static_cast(p_din), + dout_length, + din_length, + window_lengths, + window_strides, + window_dilations); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp index b49e1096829695f7f7d5b0f651a20f7aebc03f6b..aec5a65ccf41eafcbe6ebdf790cc75b0adbf2966 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp index 17a96e9f6f6c8039485afcab82d5908d92a59a24..6d1d5c8e239de8f57df1baa82d4981e8f30e54b6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp index bb62332d1ad54eca7fbb8d02f4b511b989530426..1ef33501859fad0810f8209f981d8f8717d5da41 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -10,8 +10,7 @@ #include "ck/tensor_operation/gpu/device/device_normalization.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -20,11 +19,16 @@ namespace tensor_operation { namespace device { // Y = Normalization(X, Beta, Gamma) +// M: Invarient length +// K: Reduce length (Calculate mean and variance along K dimension) +// eg. Length = [N, C, H, W], reduce dim = [C, H, W] +// Then, M = N, K = C * H * W template struct DeviceNormalizationImpl : public DeviceNormalization @@ -61,19 +66,24 @@ struct DeviceNormalizationImpl : public DeviceNormalization& inLengths, const std::vector& inStrides, - int blkGroupSize, int numBlockTileIteration) { - constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t numSrcDim = Rank; - static constexpr bool reduceAllDim = (NumInvariantDim == 0); const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); @@ -117,10 +127,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization{}); const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); - const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; const auto inPad_M = math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; - const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; + const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength; auto in_grid_desc_m_k_padded = transform_tensor_descriptor( in_grid_desc_m_k, @@ -132,7 +141,37 @@ struct DeviceNormalizationImpl : public DeviceNormalization& lengths, + const std::vector& strides) + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + + const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{}); + const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{}); + + const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto grid_desc_m = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(InvariantDims{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + const auto pad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = transform_tensor_descriptor( + grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, pad_M)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return grid_desc_m_padded; + } + + using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1)); + using GridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1})); struct Argument : public BaseArgument { @@ -141,17 +180,23 @@ struct DeviceNormalizationImpl : public DeviceNormalization gammaStrides, const std::vector betaStrides, const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, const std::vector reduceDims, YElementwiseOperation y_elementwise_op, double epsilon, const XDataType* p_x, const GammaDataType* p_gamma, const BetaDataType* p_beta, - YDataType* p_y) + YDataType* p_y, + SaveMeanInvStdDataType* p_saveMean, + SaveMeanInvStdDataType* p_saveInvStd) : p_x_(p_x), p_gamma_(p_gamma), p_beta_(p_beta), p_y_(p_y), + p_saveMean_(p_saveMean), + p_saveInvStd_(p_saveInvStd), y_elementwise_op_(y_elementwise_op) { epsilon_ = static_cast(epsilon); @@ -161,30 +206,31 @@ struct DeviceNormalizationImpl : public DeviceNormalization(yStrides, reduceDims); gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); + saveMeanStrides_ = saveMeanStrides; + saveInvStdStrides_ = saveInvStdStrides; - long_index_t invariant_total_length; - long_index_t reduce_total_length; + std::tie(MRaw_, KRaw_) = get_2d_lengths(Lengths_); - std::tie(invariant_total_length, reduce_total_length) = - get_2d_lengths(Lengths_); + numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize); - blkGroupSize_ = 1; - numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize); - gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / - M_BlockTileSize * blkGroupSize_; - - x_grid_desc_m_k_ = - MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_); + x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_); gamma_grid_desc_m_k_ = - MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_); + MakeSrc2dDescriptor(Lengths_, gammaStrides_, numBlockTileIteration_); beta_grid_desc_m_k_ = - MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_); - y_grid_desc_m_k_ = - MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_); + MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_); + y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_); + save_mean_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides); + save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides); isSweeponce_ = x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length_ = 1; + else + invariant_lowest_length_ = Lengths_[NumInvariantDim - 1]; } ComputeDataType epsilon_; @@ -193,16 +239,19 @@ struct DeviceNormalizationImpl : public DeviceNormalization Lengths_; std::vector xStrides_; std::vector gammaStrides_; std::vector betaStrides_; std::vector yStrides_; + std::vector saveMeanStrides_; + std::vector saveInvStdStrides_; YElementwiseOperation y_elementwise_op_; - int blkGroupSize_; int numBlockTileIteration_; size_t gridSize_; @@ -210,7 +259,14 @@ struct DeviceNormalizationImpl : public DeviceNormalization(arg.isSweeponce_); float avg_time = 0; @@ -249,12 +308,16 @@ struct DeviceNormalizationImpl : public DeviceNormalization(p_arg); - constexpr index_t NumInvariantDim = Rank - NumReduceDim; - if constexpr(XYSrcVectorDim == 0) { if constexpr(NumInvariantDim == 0) @@ -281,10 +342,15 @@ struct DeviceNormalizationImpl : public DeviceNormalizationinvariant_lowest_length_); + if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) return false; - if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) + if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0) + return false; + + if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0) return false; }; } @@ -295,12 +361,12 @@ struct DeviceNormalizationImpl : public DeviceNormalizationLengths_[Rank - 1] % XSrcVectorSize != 0) return false; - }; - if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) - { - return false; - } + if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) + { + return false; + } + }; // if fastest dim is not reduced if constexpr(GammaSrcVectorDim == 0) @@ -326,7 +392,7 @@ struct DeviceNormalizationImpl : public DeviceNormalizationbetaStrides_[NumInvariantDim - 1] != 1) return (false); - if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) + if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0) return (false); } else // if fastest dim is reduced @@ -338,6 +404,9 @@ struct DeviceNormalizationImpl : public DeviceNormalizationinvariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0) + return false; + return true; }; @@ -347,6 +416,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization gammaStrides, const std::vector betaStrides, const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, const std::vector reduceDims, double epsilon, const void* p_x, @@ -354,27 +425,30 @@ struct DeviceNormalizationImpl : public DeviceNormalization(lengths, xStrides, gammaStrides, betaStrides, yStrides, + saveMeanStrides, + saveInvStdStrides, reduceDims, y_elementwise_op, epsilon, static_cast(p_x), static_cast(p_gamma), static_cast(p_beta), - static_cast(p_y)); + static_cast(p_y), + static_cast(p_saveMean), + static_cast(p_saveInvStd)); }; std::unique_ptr MakeInvokerPointer() override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e969e493c6095e4da388044f94d73d5262a615e4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp @@ -0,0 +1,748 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +template +__global__ void +kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k, + const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, + index_t num_k_block_tile_iteration, + const XDataType* const __restrict__ p_x_global, + WorkspaceMeanVarDataType* const __restrict__ p_welford_mean, + WorkspaceMeanVarDataType* const __restrict__ p_welford_variance, + int32_t* const __restrict__ p_welford_count) +{ + GridwiseWelford::Run(x_grid_desc_m_k, + mean_var_grid_desc_m_kblock, + num_k_block_tile_iteration, + p_x_global, + p_welford_mean, + p_welford_variance, + p_welford_count); +}; + +template +__global__ void +kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, + const CountGridDesc_M_KBlock count_grid_desc_m_kblock, + const XYGammaBetaGridDesc_M_K x_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K y_grid_desc_m_k, + const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m, + const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m, + index_t num_k_mean_var_count_iteration, + index_t num_k_block_tile_iteration, + index_t k_grid_size, + ComputeDataType epsilon, + const WorkspaceMeanVarDataType* const p_mean_global, + const WorkspaceMeanVarDataType* const p_variance_global, + const int32_t* const p_welford_count_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, + const YElementwiseOperation y_elementwise_op) +{ + GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock, + count_grid_desc_m_kblock, + x_grid_desc_m_k, + gamma_grid_desc_m_k, + beta_grid_desc_m_k, + y_grid_desc_m_k, + save_mean_grid_desc_m, + save_inv_std_grid_desc_m, + num_k_mean_var_count_iteration, + num_k_block_tile_iteration, + k_grid_size, + epsilon, + p_mean_global, + p_variance_global, + p_welford_count_global, + p_x_global, + p_gamma_global, + p_beta_global, + p_y_global, + p_save_mean_global, + p_save_inv_std_global, + y_elementwise_op); +}; +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Y = Normalization(X, Beta, Gamma) +// M: Invarient length +// K: Reduce length (Calculate mean and variance along K dimension) +// eg. Length = [N, C, H, W], reduce dim = [C, H, W] +// Then, M = N, K = C * H * W +template +struct DeviceNormalizationSplitKImpl : public DeviceNormalization +{ + using WorkspaceMeanVarDataType = SaveMeanInvStdDataType; + + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || + (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) || + (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); + + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); // TODO + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int kBlockSize, + int numBlockTileIteration) + { + static constexpr index_t numSrcDim = Rank; + + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * kBlockSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + template + static auto MakeWorkspaceMeanVarDescriptor_M_K(index_t M, index_t K) + { + const auto grid_desc_m_k = + make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); + return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{}); + } + + template + static auto MakeWorkspaceCountDescriptor_M_K(index_t M, index_t K) + { + const auto grid_desc_m_k = + make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1)); + return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{}); + } + + static auto MakeSaveMeanInvStdDescriptor_M(const std::vector& lengths, + const std::vector& strides) + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + + const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{}); + const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{}); + + const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto grid_desc_m = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(InvariantDims{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + const auto pad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = transform_tensor_descriptor( + grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, pad_M)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return grid_desc_m_padded; + } + + using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); + using Kernel1MeanVarGridDesc_M_KBlock = + decltype(MakeWorkspaceMeanVarDescriptor_M_K, 1, 1>(1, 1)); + + using Kernel2MeanVarGridDesc_M_KBlock = + decltype(MakeWorkspaceMeanVarDescriptor_M_K, 1, 1>(1, 1)); + + using Kernel2CountGridDesc_M_KBlock = + decltype(MakeWorkspaceCountDescriptor_M_K, 1, 1>(1, 1)); + + using SaveMeanInvStdGridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1})); + + using GridwiseWelford = GridwiseNormalizationSplitK1st; + + using GridwiseWelfordNormalization = + GridwiseNormalizationSplitK2nd; + + struct Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, + const std::vector reduceDims, + YElementwiseOperation y_elementwise_op, + double epsilon, + const XDataType* p_x, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y, + SaveMeanInvStdDataType* p_saveMean, + SaveMeanInvStdDataType* p_saveInvStd) + : p_x_(p_x), + p_gamma_(p_gamma), + p_beta_(p_beta), + p_y_(p_y), + p_saveMean_(p_saveMean), + p_saveInvStd_(p_saveInvStd), + p_workspace_mean_{nullptr}, + p_workspace_var_{nullptr}, + p_workspace_count_{nullptr}, + y_elementwise_op_(y_elementwise_op) + { + epsilon_ = static_cast(epsilon); + + Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims); + gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); + betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); + saveMeanStrides_ = saveMeanStrides; + saveInvStdStrides_ = saveInvStdStrides; + + std::tie(MRaw_, KRaw_) = get_2d_lengths(Lengths_); + + numBlockTileIteration_ = 1; + while(true) + { + int testKGridSize = + math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_); + + // we want the kGridSize_ be not more than 128 + if(testKGridSize <= 128) + break; + + ++numBlockTileIteration_; + }; + + kGridSize_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_); + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize) * kGridSize_; + + // We do not use vector load for mean, var and count + static constexpr index_t K_MeanVarCountBlockTileSize = KThreadClusterSize; + + numMeanVarCountIteration_ = + math::integer_divide_ceil(kGridSize_, K_MeanVarCountBlockTileSize); + + x_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, xStrides_, kGridSize_, numBlockTileIteration_); + gamma_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, gammaStrides_, kGridSize_, numBlockTileIteration_); + beta_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, betaStrides_, kGridSize_, numBlockTileIteration_); + y_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, yStrides_, kGridSize_, numBlockTileIteration_); + + save_mean_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides); + save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides); + + // We don't need to pad in K dimension for Welford1. Set KPerTile 1. + kernel1_mean_var_grid_desc_m_kblock_ = + MakeWorkspaceMeanVarDescriptor_M_K, M_BlockTileSize, 1>( + MRaw_, kGridSize_); + + kernel2_mean_var_grid_desc_m_kblock_ = + MakeWorkspaceMeanVarDescriptor_M_K, + M_BlockTileSize, + K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); + + kernel2_count_grid_desc_m_kblock_ = + MakeWorkspaceCountDescriptor_M_K, + M_BlockTileSize, + K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length_ = 1; + else + invariant_lowest_length_ = Lengths_[NumInvariantDim - 1]; + } + + ComputeDataType epsilon_; + + const XDataType* p_x_; + const GammaDataType* p_gamma_; + const BetaDataType* p_beta_; + YDataType* p_y_; + SaveMeanInvStdDataType* p_saveMean_; + SaveMeanInvStdDataType* p_saveInvStd_; + void* p_workspace_mean_; + void* p_workspace_var_; + void* p_workspace_count_; + + std::vector Lengths_; + std::vector xStrides_; + std::vector gammaStrides_; + std::vector betaStrides_; + std::vector yStrides_; + std::vector saveMeanStrides_; + std::vector saveInvStdStrides_; + + YElementwiseOperation y_elementwise_op_; + + int kGridSize_; + int numMeanVarCountIteration_; + int numBlockTileIteration_; + size_t gridSize_; + + SrcGridDesc_M_K x_grid_desc_m_k_; + SrcGridDesc_M_K gamma_grid_desc_m_k_; + SrcGridDesc_M_K beta_grid_desc_m_k_; + SrcGridDesc_M_K y_grid_desc_m_k_; + SaveMeanInvStdGridDesc_M save_mean_grid_desc_m_; + SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m_; + + Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_; + Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_; + Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_; + + index_t MRaw_; // invarient length + index_t KRaw_; // reduce length + + index_t invariant_lowest_length_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.p_workspace_mean_ == nullptr || arg.p_workspace_var_ == nullptr || + arg.p_workspace_count_ == nullptr) + throw std::runtime_error("wrong! WorkSpace pointer has not been set"); + + auto kernel1 = kernel_normalizationSplitK1st; + + auto kernel2 = kernel_normalizationSplitK2nd; + + float avg_time = 0; + avg_time += launch_and_time_kernel( + stream_config, + kernel1, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.kernel1_mean_var_grid_desc_m_kblock_, + arg.numBlockTileIteration_, + arg.p_x_, + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_)); + + avg_time += launch_and_time_kernel( + stream_config, + kernel2, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.kernel2_mean_var_grid_desc_m_kblock_, + arg.kernel2_count_grid_desc_m_kblock_, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.beta_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.save_mean_grid_desc_m_, + arg.save_inv_std_grid_desc_m_, + arg.numMeanVarCountIteration_, + arg.numBlockTileIteration_, + arg.kGridSize_, + arg.epsilon_, + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_), + arg.p_x_, + arg.p_gamma_, + arg.p_beta_, + arg.p_y_, + arg.p_saveMean_, + arg.p_saveInvStd_, + arg.y_elementwise_op_); + + return avg_time; + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + int welford_size = pArg_->MRaw_ * pArg_->kGridSize_; + + // workspace for welford intermediate mean + workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64; + + // workspace for welford intermediate variance + workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64; + + // workspace for welford intermediate count + workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64; + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + int welford_size = pArg_->MRaw_ * pArg_->kGridSize_; + + // setup buffer used for intermediate welford mean + pArg_->p_workspace_mean_ = static_cast(pArg_->p_workspace_); + + index_t mean_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType); + mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); + + // setup buffer used for intermediate welford varirance + pArg_->p_workspace_var_ = reinterpret_cast(pArg_->p_workspace_mean_) + mean_space_sz; + + index_t variance_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType); + variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); + + // setup buffer used for intermediate welford count + pArg_->p_workspace_count_ = + reinterpret_cast(pArg_->p_workspace_var_) + variance_space_sz; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + if constexpr(XYVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return false; + } + else + { + if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0) + return false; + + if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0) + return false; + }; + } + else + { + if(p_arg_->xStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0) + return false; + + if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) + return false; + }; + + // if fastest dim is not reduced + if constexpr(GammaSrcVectorDim == 0) + { + if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return false; + } + else // if fastest dim is reduced + { + if(p_arg_->gammaStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return false; + } + + // if fastest dim is not reduced + if constexpr(BetaSrcVectorDim == 0) + { + if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0) + return false; + } + else // if fastest dim is reduced + { + if(p_arg_->betaStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0) + return false; + } + + if(p_arg_->kGridSize_ <= 1) + return false; + + if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0) + return false; + + return true; + }; + + std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, + const std::vector reduceDims, + double epsilon, + const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + void* p_saveMean, + void* p_saveInvStd, + YElementwiseOperation y_elementwise_op) override + { + if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank || + betaStrides.size() != Rank || yStrides.size() != Rank || + saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(lengths, + xStrides, + gammaStrides, + betaStrides, + yStrides, + saveMeanStrides, + saveInvStdStrides, + reduceDims, + y_elementwise_op, + epsilon, + static_cast(p_x), + static_cast(p_gamma), + static_cast(p_beta), + static_cast(p_y), + static_cast(p_saveMean), + static_cast(p_saveInvStd)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationSplitKImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "XYSrcVectorDim_" << XYVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp index 7b96373c0ff1fc86a79c4248db35dd5812f2c5b1..17dab08332166e23a93d5c8f25ecfec7556519ac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp b/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp index bfde40cda2a5369015b08de568f72c786295848f..c94c568c49aca819997cd8a1d32c536d7527a551 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp @@ -1,18 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include -#include - -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" -#include "ck/tensor_operation/gpu/device/device_pool2d_fwd.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp" namespace ck { namespace tensor_operation { @@ -20,305 +11,100 @@ namespace device { template -struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd +struct DevicePool2dFwd_NHWC_NHWC : public DevicePool3dFwd_NDHWC_NDHWC { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - - using IndexDataType = int32_t; - - using ReduceOperation = typename reduce_binary_operator::opType; - - using InElementwiseOperation = - typename reduce_unary_operator::InElementwiseOperation; - - using AccElementwiseOperation = - typename reduce_unary_operator::AccElementwiseOperation; - - static constexpr index_t InSrcOutDstVectorDim = - 0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is - // not reduced. - - static constexpr ck::index_t ReduceM_BlockTileSize = - ReduceMThreadClusterSize * ReduceMThreadSliceSize; - static constexpr ck::index_t ReduceK_BlockTileSize = - ReduceKThreadClusterSize * ReduceKThreadSliceSize; - - static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N, - ck::index_t C, - std::array input_spatial_lengths, - std::array window_spatial_lengths, - std::array output_spatial_lengths, - std::array window_strides, - std::array input_left_pads, - std::array input_right_pads) - { - const index_t Hi = input_spatial_lengths[0]; - const index_t Wi = input_spatial_lengths[1]; - - const index_t Ho = output_spatial_lengths[0]; - const index_t Wo = output_spatial_lengths[1]; - - const index_t Y = window_spatial_lengths[0]; - const index_t X = window_spatial_lengths[1]; - - const index_t ConvStrideH = window_strides[0]; - const index_t ConvStrideW = window_strides[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - const index_t ReduceMRaw = N * Ho * Wo * C; - const index_t ReduceMPad = - math::integer_least_multiple(ReduceMRaw, ReduceM_BlockTileSize) - ReduceMRaw; - - const index_t ReduceKRaw = Y * X; - const index_t ReduceKPad = - math::integer_least_multiple(ReduceKRaw, ReduceK_BlockTileSize) - ReduceKRaw; - - // A[ReduceM, ReduceK] - const auto in_grid_desc_n_hi_wi_c = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - - const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor( - in_grid_desc_n_hi_wi_c, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor( - in_grid_desc_n_hip_wip_c, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_grid_desc_reducemraw_reducekraw = - transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)), - make_merge_transform(make_tuple(Y, X))), - make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor( - in_grid_desc_reducemraw_reducekraw, - make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad), - make_right_pad_transform(ReduceKRaw, ReduceKPad)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // B[ReduceM] - const auto out_grid_desc_reducemraw = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo * C)); - - const auto out_grid_desc_reducem = transform_tensor_descriptor( - out_grid_desc_reducemraw, - make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - - return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); - } - - using ABGridDescs = decltype( - MakeABGridDescriptor_A_M_K_B_M(1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); - - using AGridDesc_M_K = remove_cvref_t; - using BGridDesc_M = remove_cvref_t; - - // TODO - struct Argument : public BaseArgument - { - Argument(const InDataType* p_in_dev, - OutDataType* p_out_dev, - int* p_out_indices_dev, - ck::index_t N, - ck::index_t C, - std::array& input_spatial_lengths, - std::array& window_spatial_lengths, - std::array& output_spatial_lengths, - std::array& window_strides, - std::array& input_left_pads, - std::array& input_right_pads) - : p_in_dev_{p_in_dev}, - p_out_dev_{p_out_dev}, - p_out_indices_dev_{p_out_indices_dev}, - a_grid_desc_m_k_{}, - b_grid_desc_m_{} - { - const auto descs = MakeABGridDescriptor_A_M_K_B_M(N, - C, - input_spatial_lengths, - window_spatial_lengths, - output_spatial_lengths, - window_strides, - input_left_pads, - input_right_pads); - - a_grid_desc_m_k_ = descs[I0]; - b_grid_desc_m_ = descs[I1]; - - invariant_lowest_length_ = C; - reduce_lowest_length_ = window_spatial_lengths[1]; - - int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1]; - - std::tie(in_element_op_, acc_element_op_) = - reduce_unary_operator::GetElementwiseOperator(reduceLength); - } - - const InDataType* p_in_dev_; - OutDataType* p_out_dev_; - int* p_out_indices_dev_; - AGridDesc_M_K a_grid_desc_m_k_; - BGridDesc_M b_grid_desc_m_; - InElementwiseOperation in_element_op_; - AccElementwiseOperation acc_element_op_; - - // for checking vector load/store - ck::index_t invariant_lowest_length_; - ck::index_t reduce_lowest_length_; - }; - - struct Invoker : public BaseInvoker - { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - using gridwise_reduce = - GridwiseReduction_mk_to_m_threadwise; - const auto kernel = kernel_reduce_threadwise; - - ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0); - - const index_t grid_size = (ReduceM / ReduceM_BlockTileSize); - - return launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.a_grid_desc_m_k_, - arg.b_grid_desc_m_, - arg.in_element_op_, - arg.acc_element_op_, - float(1), - arg.p_in_dev_, - nullptr, - float(0), - arg.p_out_dev_, - arg.p_out_indices_dev_); - } - - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - const Argument* pArg = dynamic_cast(p_arg); - - if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0) - { - return (false); - } - - return (true); - } - std::unique_ptr MakeArgumentPointer(const void* p_in_dev, void* p_out_dev, void* p_out_indices_dev, - ck::index_t N, - ck::index_t C, - std::array input_spatial_lengths, - std::array window_spatial_lengths, - std::array output_spatial_lengths, - std::array window_strides, - std::array input_left_pads, - std::array input_right_pads) override + std::vector input_lengths, + std::vector window_lengths, + std::vector output_lengths, + std::vector input_stride, + std::vector output_stride, + std::vector indices_stride, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector pooling_dims) override { - return std::make_unique(static_cast(p_in_dev), - static_cast(p_out_dev), - static_cast(p_out_indices_dev), - N, - C, - input_spatial_lengths, - window_spatial_lengths, - output_spatial_lengths, - window_strides, - input_left_pads, - input_right_pads); - } - - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<" << BlockSize << ","; - str << "M_C" << ReduceMThreadClusterSize << "_S" << ReduceMThreadSliceSize << ","; - str << "K_C" << ReduceKThreadClusterSize << "_S" << ReduceKThreadSliceSize << ","; - str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; - // clang-format on - - return str.str(); + static constexpr index_t InOutRank = 4; + static constexpr index_t WindowRank = 2; + + if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank || + input_lengths.size() != InOutRank || window_strides.size() != WindowRank || + window_dilations.size() != WindowRank || input_left_pads.size() != WindowRank || + input_right_pads.size() != WindowRank) + throw std::runtime_error("dimension is incorrect"); + + if(pooling_dims != std::vector{2, 3}) + throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far"); + + // NCHW to NCDHW + input_lengths.insert(input_lengths.begin() + 2, 1); + output_lengths.insert(output_lengths.begin() + 2, 1); + input_stride.insert(input_stride.begin() + 2, 0); + output_stride.insert(output_stride.begin() + 2, 0); + indices_stride.insert(indices_stride.begin() + 2, 0); + + // YX to ZYX + window_lengths.insert(window_lengths.begin(), 1); + window_strides.insert(window_strides.begin(), 0); + window_dilations.insert(window_dilations.begin(), 0); + input_left_pads.insert(input_left_pads.begin(), 0); + input_right_pads.insert(input_right_pads.begin(), 0); + + pooling_dims = {2, 3, 4}; + + return DevicePool3D::MakeArgumentPointer(p_in_dev, + p_out_dev, + p_out_indices_dev, + input_lengths, + window_lengths, + output_lengths, + input_stride, + output_stride, + indices_stride, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + pooling_dims); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp b/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp new file mode 100644 index 0000000000000000000000000000000000000000..384805a0a3de7f8cb083daaf1588e8630d1ad6a9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp @@ -0,0 +1,411 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePool3dFwd_NDHWC_NDHWC : public DevicePoolFwd<5, + 3, + InDataType, + OutDataType, + IndexDataType, + tensor_layout::convolution::NDHWC, + tensor_layout::convolution::NDHWC, + ReduceOpId, + OutputIndex> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr index_t InOutRank = 5; + static constexpr index_t WindowRank = 3; + + using ReduceOperation = typename reduce_binary_operator::opType; + + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + + using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + + static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeABGridDescriptor_A_M_K_B_M(std::vector input_ncdhw_lengths, + std::vector output_ncdhw_lengths, + std::vector input_ncdhw_stride, + std::vector output_ncdhw_stride, + std::vector window_spatial_zyx_lengths, + std::vector window_zyx_strides, + std::vector window_zyx_dilations, + std::vector input_left_dhw_pads, + std::vector input_right_dhw_pads) + { + const index_t N = input_ncdhw_lengths[0]; + const index_t C = input_ncdhw_lengths[1]; + const index_t Di = input_ncdhw_lengths[2]; + const index_t Hi = input_ncdhw_lengths[3]; + const index_t Wi = input_ncdhw_lengths[4]; + + const index_t Do = output_ncdhw_lengths[2]; + const index_t Ho = output_ncdhw_lengths[3]; + const index_t Wo = output_ncdhw_lengths[4]; + + const index_t Z = window_spatial_zyx_lengths[0]; + const index_t Y = window_spatial_zyx_lengths[1]; + const index_t X = window_spatial_zyx_lengths[2]; + + const index_t WindowStrideD = window_zyx_strides[0]; + const index_t WindowStrideH = window_zyx_strides[1]; + const index_t WindowStrideW = window_zyx_strides[2]; + + const index_t WindowDilationD = window_zyx_dilations[0]; + const index_t WindowDilationH = window_zyx_dilations[1]; + const index_t WindowDilationW = window_zyx_dilations[2]; + + const index_t InLeftPadD = input_left_dhw_pads[0]; + const index_t InLeftPadH = input_left_dhw_pads[1]; + const index_t InLeftPadW = input_left_dhw_pads[2]; + + const index_t InRightPadD = input_right_dhw_pads[0]; + const index_t InRightPadH = input_right_dhw_pads[1]; + const index_t InRightPadW = input_right_dhw_pads[2]; + + const index_t MRaw = N * Do * Ho * Wo * C; + const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; + + const index_t KRaw = Z * Y * X; + const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; + + // A[ReduceM, ReduceK] + const index_t Ni_stride = input_ncdhw_stride[0]; + const index_t Ci_stride = input_ncdhw_stride[1]; + const index_t Di_stride = input_ncdhw_stride[2]; + const index_t Hi_stride = input_ncdhw_stride[3]; + const index_t Wi_stride = input_ncdhw_stride[4]; + + const auto in_grid_desc_n_di_hi_wi_c = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(Ni_stride, Di_stride, Hi_stride, Wi_stride, Ci_stride)); + + const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor( + in_grid_desc_n_di_hi_wi_c, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor( + in_grid_desc_n_dip_hip_wip_c, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor( + in_grid_desc_n_z_do_y_ho_x_wo_c, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C)), + make_merge_transform(make_tuple(Z, Y, X))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor( + in_grid_desc_reducemraw_reducekraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // B[ReduceM] + const index_t No_stride = output_ncdhw_stride[0]; + const index_t Co_stride = output_ncdhw_stride[1]; + const index_t Do_stride = output_ncdhw_stride[2]; + const index_t Ho_stride = output_ncdhw_stride[3]; + const index_t Wo_stride = output_ncdhw_stride[4]; + + const auto out_grid_desc_n_do_ho_wo_c = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(No_stride, Do_stride, Ho_stride, Wo_stride, Co_stride)); + + const auto out_grid_desc_reducemraw = transform_tensor_descriptor( + out_grid_desc_n_do_ho_wo_c, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto out_grid_desc_reducem = + transform_tensor_descriptor(out_grid_desc_reducemraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); + } + + using ABGridDescs = + decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {})); + + using AGridDesc_M_K = remove_cvref_t; + using BGridDesc_M = remove_cvref_t; + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_dev, + OutDataType* p_out_dev, + IndexDataType* p_out_indices_dev, + std::vector& input_ncdhw_lengths, + std::vector& output_ncdhw_lengths, + std::vector& input_ncdhw_stride, + std::vector& output_ncdhw_stride, + std::vector&, // indices_ncdhw_stride + std::vector& window_spatial_zyx_lengths, + std::vector& window_zyx_strides, + std::vector& window_zyx_dilations, + std::vector& input_left_dhw_pads, + std::vector& input_right_dhw_pads) + : p_in_dev_{p_in_dev}, + p_out_dev_{p_out_dev}, + p_out_indices_dev_{p_out_indices_dev}, + a_grid_desc_m_k_{}, + b_grid_desc_m_{}, + input_ncdhw_lengths_{input_ncdhw_lengths}, + output_ncdhw_lengths_{output_ncdhw_lengths}, + input_ncdhw_stride_{input_ncdhw_stride}, + output_ncdhw_stride_{output_ncdhw_stride} + { + const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_ncdhw_lengths, + output_ncdhw_lengths, + input_ncdhw_stride, + output_ncdhw_stride, + window_spatial_zyx_lengths, + window_zyx_strides, + window_zyx_dilations, + input_left_dhw_pads, + input_right_dhw_pads); + + a_grid_desc_m_k_ = descs[I0]; + b_grid_desc_m_ = descs[I1]; + + int32_t reduceLength = window_spatial_zyx_lengths[0] * window_spatial_zyx_lengths[1] * + window_spatial_zyx_lengths[2]; + + std::tie(in_element_op_, acc_element_op_) = + reduce_unary_operator::GetElementwiseOperator(reduceLength); + } + + const InDataType* p_in_dev_; + OutDataType* p_out_dev_; + IndexDataType* p_out_indices_dev_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_M b_grid_desc_m_; + + InElementwiseOperation in_element_op_; + AccElementwiseOperation acc_element_op_; + + // for checking vector load/store + std::vector input_ncdhw_lengths_; + std::vector output_ncdhw_lengths_; + std::vector input_ncdhw_stride_; + std::vector output_ncdhw_stride_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + // for NDHWC, the dim C is the fastest dimension, and is not reduced. + // Hence, it is in M dimension for reduction kernel. + static constexpr index_t InSrcOutDstVectorDim = 0; // 0: M, 1: K + + using gridwise_reduce = + GridwiseReduction_mk_to_m_threadwise; + + const auto kernel = + kernel_reduce_threadwise; + + ck::index_t M = arg.a_grid_desc_m_k_.GetLength(I0); + + const index_t grid_size = (M / M_BlockTileSize); + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.a_grid_desc_m_k_, + arg.b_grid_desc_m_, + arg.in_element_op_, + arg.acc_element_op_, + float(1), + arg.p_in_dev_, + nullptr, + float(0), + arg.p_out_dev_, + arg.p_out_indices_dev_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + // C should be fastest dimension + if(pArg->input_ncdhw_stride_[1] != 1) + return false; + + for(int i = 0; i < InOutRank; ++i) + { + if(pArg->input_ncdhw_stride_[i] == 1 && + pArg->input_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0) + return false; + + if(pArg->output_ncdhw_stride_[i] == 1 && + pArg->output_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0) + return false; + } + + return true; + } + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in_dev, + void* p_out_dev, + void* p_out_indices_dev, + std::vector input_ncdhw_lengths, + std::vector window_zyx_lengths, + std::vector output_ncdhw_lengths, + std::vector input_ncdhw_stride, + std::vector output_ncdhw_stride, + std::vector indices_ncdhw_stride, + std::vector window_zyx_strides, + std::vector window_zyx_dilations, + std::vector input_left_dhw_pads, + std::vector input_right_dhw_pads, + std::vector pooling_dims) override + { + if(input_ncdhw_lengths.size() != InOutRank || window_zyx_lengths.size() != WindowRank || + input_ncdhw_lengths.size() != InOutRank || window_zyx_strides.size() != WindowRank || + window_zyx_dilations.size() != WindowRank || input_left_dhw_pads.size() != WindowRank || + input_right_dhw_pads.size() != WindowRank) + throw std::runtime_error("dimension is incorrect"); + + if(pooling_dims != std::vector{2, 3, 4}) + throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far"); + + if(output_ncdhw_stride != indices_ncdhw_stride) + throw std::runtime_error( + "output_ncdhw_stride need to be equal to indices_ncdhw_stride for now"); + + return std::make_unique(static_cast(p_in_dev), + static_cast(p_out_dev), + static_cast(p_out_indices_dev), + input_ncdhw_lengths, + output_ncdhw_lengths, + input_ncdhw_stride, + output_ncdhw_stride, + indices_ncdhw_stride, + window_zyx_lengths, + window_zyx_strides, + window_zyx_dilations, + input_left_dhw_pads, + input_right_dhw_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DevicePool3dFwd_NDHWC_NDHWC<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7334da0e35af5ce92665decc4897ebe48393b9c3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_put_element.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// output[indices] = input +template +struct DevicePutElementImpl + : public DevicePutElement +{ + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + constexpr auto I0 = Number<0>{}; + + const auto m = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * InVectorSize; + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(index_t length, index_t gridSize, index_t blockSize) + { + const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length)); + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + + using InGrid1dDesc = decltype(MakeDescriptor_M(1, 1, 1)); + + using GridwisePutElement = GridwisePutElement_1D; + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_input, + const IndexDataType* p_indices, + OutDataType* p_output, + index_t input_length, + ElementwiseOperation elementwise_op) + : p_input_{p_input}, + p_indices_{p_indices}, + p_output_{p_output}, + input_length_raw_{input_length}, + elementwise_op_{elementwise_op}, + blockSize_{256} + { + } + + const InDataType* p_input_; + const IndexDataType* p_indices_; + OutDataType* p_output_; + index_t input_length_raw_; + ElementwiseOperation elementwise_op_; + index_t blockSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + InGrid1dDesc in_grid_desc = + MakeDescriptor_M(arg.input_length_raw_, gridSize, arg.blockSize_); + + const auto kernel = kernel_put_element_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + in_grid_desc, + arg.p_input_, + arg.p_indices_, + arg.p_output_, + arg.elementwise_op_); + return elapsed_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg->input_length_raw_ % InVectorSize != 0) + { + return false; + } + return true; + } + + std::unique_ptr MakeArgumentPointer(const void* p_input, + const void* p_indices, + void* p_output, + index_t input_length, + index_t, + ElementwiseOperation elementwise_op) override + { + return std::make_unique(static_cast(p_input), + static_cast(p_indices), + static_cast(p_output), + input_length, + elementwise_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp index 5dc051be3cb7f23fef076921e826a7e55e017d76..2481c5c76971db6bb2ddc390510089af5f656ae6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp b/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp index c7868537fe8220c640ba24d346eea5dc219490fa..bf3deeb57acbdff95995f5da096ab574140bcb42 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp b/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp index a1d976f1a17c698c538b0f7ba1b2b200162e8eb2..6c5895b010659cafb98e4a78bc6587d0222e02bd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -28,6 +28,7 @@ template + Rank, + NumReduceDim> { - static constexpr index_t kRank = Rank; - static constexpr index_t kNumReduceDim = NumReduceDim; - static constexpr index_t kNumInvariantDim = Rank - NumReduceDim; - - virtual index_t GetRank() const override { return kRank; } - - virtual index_t GetNumReduceDim() const override { return kNumReduceDim; } - static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumSrcDim = Rank; @@ -287,13 +280,13 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0) + if(NumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp index 2f29224a754057430cd68bb551807226d40a6cb5..7a62ec04650a94d1b9f458e425ddf3b14881091d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp similarity index 98% rename from include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp index 70990e795cc7ea594a94aee24737fc3dfc6ba664..cfab7d29c9002f4fc61b154e127c4f09ab9005c9 100644 --- a/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -57,7 +57,7 @@ __global__ void const Block2ETileMap block_2_etile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = @@ -617,10 +617,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using AGridDesc_AKB_AK0_M_AK1 = remove_cvref_t; - using BGridDesc_BKB_BK0_N_BK1 = remove_cvref_t; + using AGridDesc_AKB_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BKB_BK0_N_BK1 = + remove_cvref_t; using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; @@ -886,11 +888,12 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap, has_main_loop>; - hipGetErrorString(hipMemset( + hipGetErrorString(hipMemsetAsync( arg.p_e_grid_, 0, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * - sizeof(EDataType))); + sizeof(EDataType), + stream_config.stream_id_)); return launch_and_time_kernel(stream_config, kernel, @@ -939,8 +942,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp index ea0f5897a757c408047b20e0cd4285e13f67c4f5..d6d6f74abdb5a63ef8a841d033c9f47ec59bd07e 100644 --- a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp index 70e61bc7728f9638e5bef874ab0d5caad2db5028..c66d2fc516babe3aa5a16de053d9ada89eabf8ef 100644 --- a/include/ck/tensor_operation/gpu/device/matrix_padder.hpp +++ b/include/ck/tensor_operation/gpu/device/matrix_padder.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp index d35318357a9bcc6fb087c92e27dcf435ff60796b..5351d4ef2439307ad5623f4172217937812ab29a 100644 --- a/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp +++ b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp index b44427411f9a28511fd16f6b47ee37fa84a0e2d3..b2d141fd61aa4550ef8b572e548688120634b8b9 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp b/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp index 0ec0df2c9bbbb2aac987d681a36d5b00a8349c97..713fc93ebb4eec9ce0301affd667656bf0b6c999 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/welford_helper.hpp b/include/ck/tensor_operation/gpu/device/welford_helper.hpp index 6c909b767d4602914660d045b9aaf5a2307ea57b..d7772d8764ff3cc66572775f750997fae944800e 100644 --- a/include/ck/tensor_operation/gpu/device/welford_helper.hpp +++ b/include/ck/tensor_operation/gpu/device/welford_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 136017c6d17027a67fee3b493a20a8d3fabe2bd4..9fe0931cbae44250295720f090bda56588d3fb15 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -36,6 +36,13 @@ struct Add y = x0 + type_convert(x1); }; + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const float& x1) const + { + y = type_convert(x0 + x1); + }; + template <> __host__ __device__ constexpr void operator()(half_t& y, const float& x0, const half_t& x1) const @@ -179,6 +186,13 @@ struct Bilinear y = type_convert(alpha_ * x0 + beta_ * ck::type_convert(x1)); }; + template <> + __host__ __device__ constexpr void operator()( + std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const + { + y = type_convert(x0 + ck::type_convert(x1)); + }; + float alpha_; float beta_; }; diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index ceb2b665b91529f0a97cfb1209b2de40744ee466..9f5ed6adea1915a5a1ea2f67ba65b7641db22855 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -195,6 +195,51 @@ struct AddMultiply } }; +// C = A * B +// E = C x D0 + D1 +struct MultiplyAdd +{ + template + __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ void operator()(half_t& e, + const half_t& c, + const half_t& d0, + const half_t& d1) const + { + const half_t y = (c * d0) + d1; + e = y; + } + template <> + __host__ __device__ void operator()(half_t& e, + const float& c, + const half_t& d0, + const half_t& d1) const + { + const half_t y = type_convert(c) * d0 + d1; + e = y; + } + template <> + __host__ __device__ void operator()(float& e, + const float& c, + const half_t& d0, + const half_t& d1) const + { + const float y = c * d0 + d1; + e = y; + } + template <> + __host__ __device__ void operator()(half_t& e, + const float& c, + const float& d0, + const float& d1) const + { + const float y = c * d0 + d1; + e = y; + } +}; + // E = FastGelu(C + D0 + D1) struct AddAddFastGelu { diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 2987def02a62338dc9566727588c332513398fef..af48c6c1afdb1d688702b108a1f78b107274efd4 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/data_type.hpp" #include "ck/utility/math.hpp" #include "ck/utility/math_v2.hpp" +#include "ck/utility/type_convert.hpp" namespace ck { namespace tensor_operation { @@ -26,6 +27,12 @@ struct PassThrough y = x; } + template <> + __host__ __device__ void operator()(float& y, const double& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(float& y, const float& x) const { @@ -38,6 +45,12 @@ struct PassThrough y = x; } + template <> + __host__ __device__ void operator()(half_t& y, const float& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { @@ -56,18 +69,42 @@ struct PassThrough y = type_convert(x); } + template <> + __host__ __device__ void operator()(bhalf_t& y, const half_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(float& y, const half_t& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; } + template <> + __host__ __device__ void operator()(half_t& y, const int8_t& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(int8_t& y, const int32_t& x) const { y = type_convert(x); } + template <> + __host__ __device__ void operator()(int8_t& y, const float& x) const + { + y = type_convert(x); + } + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> __host__ __device__ void operator()(int4_t& y, const int4_t& x) const @@ -75,6 +112,66 @@ struct PassThrough y = x; } #endif + + template <> + __host__ __device__ void operator()(f8_t& y, const f8_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(float& y, const f8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(f8_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(half_t& y, const f8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(f8_t& y, const half_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bf8_t& y, const bf8_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(float& y, const bf8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bf8_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(half_t& y, const bf8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bf8_t& y, const half_t& x) const + { + y = ck::type_convert(x); + } }; struct UnaryConvert @@ -86,6 +183,41 @@ struct UnaryConvert } }; +struct ConvertBF16RTN +{ + // convert to bf16 using round to nearest (rtn) + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(is_same::value, "Data type is not supported by this operation!"); + + // check X datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = bf16_convert_rtn(x); + } +}; + +struct ConvertF8SR +{ + // convert to fp8 using stochastic rounding (SR) + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + // check X datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = f8_convert_sr(x); + } +}; + struct Scale { __host__ __device__ Scale(float scale) : scale_(scale) {} @@ -311,10 +443,11 @@ struct Sigmoid __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value || + is_same::value, "Data type is not supported by this operation!"); - - y = 1 / (ck::type_convert(1) + exp(-x)); + constexpr T one = type_convert(1); + y = one / (one + ck::math::exp(-x)); }; }; @@ -324,7 +457,8 @@ struct TanH __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value || + is_same::value, "Data type is not supported by this operation!"); y = ck::math::tanh(x); @@ -335,17 +469,116 @@ struct Swish { Swish(float beta = 1.0f) : beta_(beta) {} + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + float bx = -beta_ * type_convert(x); + y = type_convert(x / (1.f + ck::math::exp(bx))); + }; + + const float beta_; +}; + +struct SoftRelu +{ + SoftRelu(float alpha = 1.f) : alpha_(alpha){}; + template __host__ __device__ void operator()(T& y, const T& x) const { static_assert(is_same::value || is_same::value || - is_same::value, + is_same::value || is_same::value || + is_same::value, "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + constexpr T one = type_convert(1); + y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; + } + const float alpha_; +}; - y = x / (ck::type_convert(1) + ck::math::exp(-beta_ * x)); - }; +struct Power +{ + Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) + : alpha_(alpha), beta_(beta), gamma_(gamma){}; - float beta_ = 1.0f; + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + T casted_beta = type_convert(beta_); + T casted_gamma = type_convert(gamma_); + T shifted_scaled_x = casted_alpha + casted_beta * x; + y = ck::math::pow(shifted_scaled_x, casted_gamma); + } + const float alpha_; + const float beta_; + const float gamma_; +}; + +struct ClippedRelu +{ + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + T casted_beta = type_convert(beta_); + y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); + } + const float alpha_; + const float beta_; +}; + +struct LeakyRelu +{ + LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + y = x >= 0 ? x : x * casted_alpha; + } + const float alpha_; +}; + +struct Elu +{ + Elu(float alpha = 1.f) : alpha_(alpha){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + y = x > 0 ? x : casted_alpha * ck::math::expm1(x); + } + const float alpha_; }; } // namespace element_wise diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp new file mode 100644 index 0000000000000000000000000000000000000000..47573107cf4aa9296e3724e1b470223e623c4b35 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp @@ -0,0 +1,704 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/utility/workgroup_synchronization.hpp" + +namespace ck { + +template +__global__ void kernel_multiblock_batchnorm_forward( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K y_grid_desc_m_k, + const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, + const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M bias_grid_desc_m, + const MeanVarGridDesc_M mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + MeanVarDataType* const __restrict__ p_welford_mean, + MeanVarDataType* const __restrict__ p_welford_variance, + int32_t* const __restrict__ p_welford_count, + int32_t* const __restrict__ p_control, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) +{ + GridwiseMultiblockBatchNormForward_::Run(x_grid_desc_m_k, + y_grid_desc_m_k, + mean_var_count_grid_desc_m_g, + mean_var_count_grid_desc_m_k, + scale_grid_desc_m, + bias_grid_desc_m, + mean_var_grid_desc_m, + get_reduce_count_per_thread, + num_k_block_tile_iteration, + epsilon, + p_x, + p_welford_mean, + p_welford_variance, + p_welford_count, + p_control, + p_scale, + p_bias, + y_elementwise_op, + p_y, + updateMovingAverage, + averageFactor, + resultRunningMean, + resultRunningVariance, + saveMeanInvVariance, + resultSaveMean, + resultSaveInvVariance); +}; + +template +struct GridwiseMultiblockBatchNormForward +{ + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{}))); + + using ThreadwiseWelford1 = + ThreadwiseWelford; + + using ThreadwiseWelford2 = + ThreadwiseWelfordMerge; + + using BlockwiseWelford1 = BlockwiseWelford; + + using BlockwiseWelford2 = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, + const XYGridDesc_M_K& y_grid_desc_m_k, + const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g, + const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k, + const ScaleBiasGridDesc_M& scale_grid_desc_m, + const ScaleBiasGridDesc_M& bias_grid_desc_m, + const MeanVarGridDesc_M& mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + MeanVarDataType* const __restrict__ p_welford_mean, + MeanVarDataType* const __restrict__ p_welford_variance, + int32_t* const __restrict__ p_welford_count, + int32_t* const __restrict__ p_control, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) + { + using ck::math::sqrt; + + const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / blkgroup_size; + const index_t block_local_id = block_global_id % blkgroup_size; + + if(block_local_id == 0) + gms_init(BlockSize / warpSize * blkgroup_size, &p_control[blkgroup_id * 2]); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + using ThreadBufferLengths_M_1 = Sequence; + + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + StaticBuffer + x_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + StaticBuffer count_thread_buf; + + StaticBuffer + tmp_mean_thread_buf; + StaticBuffer + tmp_var_thread_buf; + StaticBuffer tmp_count_thread_buf; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + constexpr auto xy_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + // Step 1: each workgroup does local welford reduction + + auto threadwise_welford_1 = ThreadwiseWelford1(); + threadwise_welford_1.max_count_ = + get_reduce_count_per_thread(block_local_id, thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k); + threadwise_welford_1.Run(x_thread_buf, mean_thread_buf, var_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + count_thread_buf(I) = threadwise_welford_1.cur_count_; + BlockwiseWelford1::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I)); + }); + + // Step 2: each workgroup writes its local welford result to workspace memory + + auto mean_global_val_buf = + make_dynamic_buffer( + p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto var_global_val_buf = + make_dynamic_buffer( + p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto count_global_val_buf = + make_dynamic_buffer( + p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto threadwise_mean_var_store_m_g = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_count_grid_desc_m_g, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_local_id), + PassThroughOp{}); + + auto threadwise_count_store_m_g = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_count_grid_desc_m_g, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_local_id), + PassThroughOp{}); + + if(thread_k_cluster_id == 0) + { + threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + mean_thread_buf, + mean_var_count_grid_desc_m_g, + mean_global_val_buf); + + threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + var_thread_buf, + mean_var_count_grid_desc_m_g, + var_global_val_buf); + + threadwise_count_store_m_g.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + count_thread_buf, + mean_var_count_grid_desc_m_g, + count_global_val_buf); + }; + + gms_barrier(&p_control[blkgroup_id * 2]); + + if(block_local_id == 0) + gms_reset(&p_control[blkgroup_id * 2]); + + // Step 3: each workgroup reads welford results from workspace memory and does final welford + // reduction + + auto threadwise_mean_var_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 0, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + auto threadwise_count_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 0, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + count_thread_buf(I) = 0; + }); + + constexpr auto mean_var_count_read_fwd_step_m_k = make_multi_index(0, KThreadClusterSize); + + int32_t reducedSize = 0; + while(reducedSize < blkgroup_size) + { + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + tmp_mean_thread_buf); + + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + var_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + tmp_var_thread_buf); + + threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k, + count_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + tmp_count_thread_buf); + + ThreadwiseWelford2::Run(tmp_mean_thread_buf, + tmp_var_thread_buf, + tmp_count_thread_buf, + mean_thread_buf, + var_thread_buf, + count_thread_buf); + + reducedSize += KThreadClusterSize; + + threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, + mean_var_count_read_fwd_step_m_k); + threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, + mean_var_count_read_fwd_step_m_k); + }; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseWelford2::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I)); + }); + + // Step 4: do normalization using the mean/variance + + StaticBuffer scale_thread_buf; + + StaticBuffer bias_thread_buf; + + StaticBuffer + y_thread_buf; + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index( + blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize), + y_elementwise_op); + + auto threadwise_scale_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + ScaleSrcVectorSize, + 1, + true>( + scale_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2, + 0, + BiasSrcVectorSize, + 1, + true>( + bias_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + const auto scale_global_val_buf = make_dynamic_buffer( + p_scale, scale_grid_desc_m.GetElementSpaceSize()); + + const auto bias_global_val_buf = make_dynamic_buffer( + p_bias, bias_grid_desc_m.GetElementSpaceSize()); + + auto y_global_val_buf = make_dynamic_buffer( + p_y, y_grid_desc_m_k.GetElementSpaceSize()); + + threadwise_scale_load.Run(scale_grid_desc_m, + scale_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + scale_thread_buf); + + threadwise_bias_load.Run(bias_grid_desc_m, + bias_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + bias_thread_buf); + + threadwise_x_load.SetSrcSliceOrigin( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType multiplier = + scale_thread_buf[Number{}] / sqrt(var_thread_buf[iM] + epsilon); + + AccDataType fused_mean_bias = + bias_thread_buf[Number{}] - mean_thread_buf[iM] * multiplier; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + // normalize + y_thread_buf(Number{}) = + x_thread_buf[Number{}] * multiplier + fused_mean_bias; + }); + }); + + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf, + y_grid_desc_m_k, + y_global_val_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_copy_fwd_step_m_k); + } + + // Step 5: update the moving average of mean and variance (optional) + + if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0) + { + StaticBuffer + running_mean_thread_buf; + StaticBuffer + running_var_thread_buf; + + auto running_mean_global_buf = make_dynamic_buffer( + resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto running_var_global_buf = make_dynamic_buffer( + resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto threadwise_mean_var_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcDstVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + running_mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + running_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf); + + AccDataType oneMinusAverageFactor = type_convert(1.0) - averageFactor; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor + + mean_thread_buf[I] * averageFactor; + running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor + + var_thread_buf[I] * averageFactor; + }); + + auto threadwise_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf, + mean_var_grid_desc_m, + running_mean_global_buf); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf, + mean_var_grid_desc_m, + running_var_global_buf); + }; + + // Step 6: save mean and inv-variance (optional) + + if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0) + { + auto result_mean_global_buf = make_dynamic_buffer( + resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto result_inv_var_global_buf = make_dynamic_buffer( + resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + var_thread_buf(I) = + type_convert(1.0f) / sqrt(epsilon + var_thread_buf[I]); + }); + + auto threadwise_mean_inv_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + mean_var_grid_desc_m, + result_mean_global_buf); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + var_thread_buf, + mean_var_grid_desc_m, + result_inv_var_global_buf); + }; + } +}; // namespace ck + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp index a72a4ee068f2bbcb1a915df0368c9b05caa298f9..4e182ec29ddd1be0ff83b72a1da956ab0353a37a 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -118,8 +118,8 @@ struct GridwiseReduceSecondHalfBatchNormBackwardFinal static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); - using ThreadReduceSrcDesc_M_1 = decltype( - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number<1>{}))); + using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{}))); using ThreadReduceDstDesc_M = decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp index 08cb0dd191303813a9283b19e38e2952fa0ccf27..a82a173500d21305cf5b643553d4ddf544a62d0d 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -161,7 +161,7 @@ struct GridwiseMultiblockWelfordFirstHalf PassThroughOp, ThreadBufferLengths_M_1, Sequence<0, 1>, - 1, + 0, 1, InMemoryDataOperationEnum::Set, 1, @@ -180,7 +180,7 @@ struct GridwiseMultiblockWelfordFirstHalf PassThroughOp, ThreadBufferLengths_M_1, Sequence<0, 1>, - 1, + 0, 1, InMemoryDataOperationEnum::Set, 1, diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp similarity index 97% rename from include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp rename to include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp index 548d7fd40ac70fe7fde5e5973a177600b0941c5c..672be91a79ed87322ce228cf558e33b505fe6f5c 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -33,7 +33,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final( const MeanVarGridDesc_M mean_var_grid_desc_m, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, - index_t num_mean_var_count_k_block_tile_iteration, AccDataType epsilon, const MeanVarDataType* const __restrict__ p_in_welford_mean, const MeanVarDataType* const __restrict__ p_in_welford_variance, @@ -59,7 +58,6 @@ __global__ void kernel_welford_second_half_batchnorm_forward_final( mean_var_grid_desc_m, blkgroup_size, num_xy_k_block_tile_iteration, - num_mean_var_count_k_block_tile_iteration, epsilon, p_in_welford_mean, p_in_welford_variance, @@ -123,8 +121,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); - using ThreadReduceSrcDesc_M_1 = decltype( - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number<1>{}))); + using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{}))); using ThreadReduceDstDesc_M = decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); @@ -152,7 +150,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal const MeanVarGridDesc_M& mean_var_grid_desc_m, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, - index_t num_mean_var_count_k_block_tile_iteration, AccDataType epsilon, const MeanVarDataType* const __restrict__ p_in_welford_mean, const MeanVarDataType* const __restrict__ p_in_welford_variance, @@ -223,7 +220,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal decltype(thread_buffer_desc_m_1), ThreadBufferLengths_M_1, Sequence<0, 1>, - 1, + 0, 1, 1, true>( @@ -239,7 +236,7 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal decltype(thread_buffer_desc_m_1), ThreadBufferLengths_M_1, Sequence<0, 1>, - 1, + 0, 1, 1, true>( @@ -257,9 +254,6 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal const auto welford_count_global_val_buf = make_dynamic_buffer( p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); - constexpr auto mean_var_count_thread_copy_step_m_k = - make_multi_index(0, KThreadClusterSize * 1); - // Step 1: do final welford reduction to get mean and variance static_for<0, MThreadSliceSize, 1>{}([&](auto I) { @@ -268,8 +262,11 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal welford_count_thread_buf(I) = 0; }); - for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration; - ++reducedTiles) + constexpr auto mean_var_count_thread_copy_step_m_k = + make_multi_index(0, KThreadClusterSize); + + int32_t reducedSize = 0; + while(reducedSize < blkgroup_size) { threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, welford_mean_global_val_buf, @@ -296,6 +293,8 @@ struct GridwiseWelfordSecondHalfBatchNormForwardFinal welford_var_thread_buf, welford_count_thread_buf); + reducedSize += KThreadClusterSize; + threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, mean_var_count_thread_copy_step_m_k); threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, diff --git a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp index 42b7e172b239e69ae0ca840dd80d8bb203a5480f..2d5dc90bfb3744709bc51219c634d875994dda65 100644 --- a/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp +++ b/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -115,8 +115,8 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}))); - using ThreadReduceSrcDesc_M_1 = decltype( - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number<1>{}))); + using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{}))); using ThreadReduceDstDesc_M = decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index fe4dce2b9dccaa9533b6182ba3259755d7ac8d8b..7bb47e9d3c4b3a38c14ced3059d9efe3c3269763 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,8 @@ #include "ck/utility/number.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" +#include +#include namespace ck { @@ -109,30 +111,57 @@ struct BlockToCTileMap_M00_N0_M01 // Rows of column-vectors // This C-tile map dynamically adjusts M01 when C-tile index is out of range -template -struct BlockToCTileMap_M00_N0_M01Adapt +template +struct BlockToCTileMap_M00_N0_M01Adapt; + +template +struct BlockToCTileMap_M00_N0_M01Adapt { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default; + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) = + default; + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) = + default; + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt& + operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default; + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt& + operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default; + + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) + : M_(M), N_(N), M01_(M01) + { + } + + template __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8) - : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n) + : BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01) { } - __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) { - const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); - const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); - const index_t grid_size = M0 * N0; + return M0 * N0; + } - return grid_size; + template + __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; } template @@ -140,8 +169,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt { auto block_1d_id = idx_top[I0]; - const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); - const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); + const auto M0 = math::integer_divide_ceil(M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock); block_1d_id = block_1d_id % (M0 * N0); // swallow batch index @@ -209,11 +238,18 @@ struct BlockToCTileMap_M00_N0_M01Adapt return true; // always valid provided that user gets grid size from CalculateGridSize() } - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } - private: + index_t M_; + index_t N_; index_t M01_; - CGridDesc_M_N c_grid_desc_m_n_; +}; + +// keep the redundant type argument for backward compatibility +template +struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt +{ + using BlockToCTileMap_M00_N0_M01Adapt:: + BlockToCTileMap_M00_N0_M01Adapt; }; // 2D slices of column-vectors in 3D space @@ -551,7 +587,8 @@ struct OffsettedBlockToCTileMap { using underlying_type = UnderlyingBlockToCTileMap; - OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start) + __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, + index_t block_start) { block_to_ctile_map_ = block_to_ctile_map; block_start_ = block_start; @@ -587,4 +624,454 @@ struct OffsettedBlockToCTileMap index_t block_start_; }; +/** + * @brief Simple tile mapping which creates 3D grid of block of threads. + * + * @paragraph Description + * This Block-to-C-tile-map creates a 3D grid (n_blocks, m_blocks, z_blocks) of thread + * blocks. The first 2D are regular 2D tiles created by division of output GEMM + * dimenions by corresponding tile size. The third dimension (Z) is a k-split dimension, + * which denotes the number of blocks we use to divide work on GEMM K dimension onto. + * + * @tparam MPerBlock Output block tile size in M dimension. + * @tparam NPerBlock Output block tile size in N dimension. + */ +template +struct BlockToCTileMap_3DGrid_KSplit +{ + + __host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default; + + __host__ __device__ constexpr auto + CalculateGridSize(index_t M, index_t N, index_t k_split) const + { + // Create 3D grid + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return std::make_tuple(N0, M0, k_split); + } + + template + __device__ constexpr auto CalculateBottomIndex(const TopIdx&) const + { + return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } +}; + +enum StreamKReductionStrategy +{ + Atomic = 0, // sk block use atomic to do reduction + Reduction, // let some workgroup responsible for doing the reduction operation +}; + +template +struct BlockToCTileMap_GemmStreamK +{ + static constexpr uint32_t min_k_iters_per_sk_block = 2; + static constexpr uint32_t MPerBlock = MPerBlock_; + static constexpr uint32_t NPerBlock = NPerBlock_; + static constexpr uint32_t KPerBlock = KPerBlock_; + static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_; + static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_; + + //-------------------------------------- + // pass to device + uint32_t sk_num_blocks; + uint32_t sk_num_big_blocks; + uint32_t dp_start_block_idx; + uint32_t reduction_start_block_idx; + uint32_t k_iters_per_big_block; + MDiv2 n_tiles; + MDiv k_iters_per_tile; + MDiv eqav_tiles_big; // for reduction + MDiv eqav_tiles_little; // for reduction + + // MDiv tile_swizzle_sub_m_rem; + //-------------------------------------- + + // prefer construct on host + BlockToCTileMap_GemmStreamK(uint32_t m, + uint32_t n, + uint32_t k, + uint32_t num_cu, + uint32_t occupancy, + uint32_t sk_blocks = 0xffffffff) + { + uint32_t num_tiles = + math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock); + k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock)); + + // one cu can hold one wg at one time, from the whole chip's point of view + // if number of wg is same as num_cu, we call it 1 dispatch + // if number of wg is 2x num_cu, we call it 2 dispatches. + // one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial + // dispatch) + // + uint32_t full_dispatches = num_tiles / num_cu; + uint32_t full_dispatch_tiles = full_dispatches * num_cu; + uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles; + + uint32_t sk_occupancy = occupancy; + uint32_t dp_tiles = full_dispatch_tiles; + uint32_t sk_tiles = partial_dispatche_tiles; + + if(full_dispatches < occupancy) + { + // in this case, we allocate all blocks as sk blocks + // sk_occupancy = occupancy - full_dispatches; + sk_occupancy = 1; // TODO: single occ seems better + dp_tiles = full_dispatch_tiles; + sk_tiles = partial_dispatche_tiles; + } + else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1)) + { + // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ... + // occupancy = 3, full_dispatches = 5, 8, 11 ... + // occupancy = 4, full_dispatches = 7, 11 ... + sk_occupancy = 1; // left 1 slot for sk occupancy + dp_tiles = full_dispatch_tiles; + sk_tiles = partial_dispatche_tiles; + } + else + { + // others, we reduce 1 dispatch from dp, together with partial dispatch, + // to construct sk dispatch + sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy); + dp_tiles = full_dispatch_tiles - num_cu; + sk_tiles = partial_dispatche_tiles + num_cu; + } + + // uint32_t dp_iters_per_block = k_iters_per_tile.get(); + uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles; + uint32_t dp_num_blocks = 0; + + { + uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1); + uint32_t max_sk_tiles = + (sk_tiles >= num_cu) ? num_cu * sk_occupancy + : math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block); + + // if use dp for sk-block, how many iters do we need + uint32_t dp_for_sk_iters = k_iters_per_tile.get(); + + uint32_t best_sk_score = + std::numeric_limits::max(); // we need to find the smallest sk iters + for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles; + tentative_sk_blocks++) + { + uint32_t tentative_sk_iters_per_block = + (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks; + uint32_t tentative_sk_iters = tentative_sk_iters_per_block; + uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles; + + // TODO: carefully adjust this parameter + // the more sk_blocks_per_tile, the worse the overhead + uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile; + if(tentative_sk_blocks % sk_tiles != 0) + { + // penalty for uneven divide + cross_sk_blocks_overhead += + sk_blocks_per_tile * tentative_sk_iters_per_block / 50; + } + + uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead; + + if(tentative_sk_score < best_sk_score) + { + best_sk_score = tentative_sk_score; + sk_num_blocks = tentative_sk_blocks; + } + } + + if(best_sk_score >= dp_for_sk_iters) + { + sk_num_blocks = 0; + } + + // give a chance to control num of sk blocks + sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks; + + if(sk_num_blocks == 0) + { + sk_num_big_blocks = 0; + k_iters_per_big_block = 0; + + dp_num_blocks = num_tiles; // all tile to be dp block + dp_start_block_idx = 0; + sk_total_iters = 0; // clear this tiles + } + else + { + // k_iters_per_sk_block is the floor of avg each ck block loop over tiles. + // we need to decide how many iters for each sk block + // let m = k_iters_per_sk_block + // some of the sk block (little) will cover m iters, some (big) will cover m+1 + // we have + // 1) l + b = sk_blocks + // 2) l * m + b * (m + 1) = sk_total_iters + // => (l + b) * m + b = sk_total_iters + // => sk_blocks * m + b = sk_total_iters + // => b = sk_total_iters - m * sk_blocks + // NOTE: big could be zero + uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; + sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; + k_iters_per_big_block = k_iters_per_sk_block + 1; + + dp_num_blocks = dp_tiles; + dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu; + } + } + n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock)); + reduction_start_block_idx = dp_start_block_idx + dp_num_blocks; + + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get()); + uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get()); + eqav_tiles_big = MDiv(upper_big / k_iters_per_tile.get()); + eqav_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); + } + +#if 0 + printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, " + "sk_num_blocks:%d, " + "sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, " + "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, " + "sk_tiles:%u, workspace(acc float):%u\n", + num_cu, + occupancy, + get_grid_dims().x, + num_tiles, + dp_tiles, + sk_num_big_blocks, + sk_num_blocks, + sk_total_iters, + dp_start_block_idx, + dp_iters_per_block, + dp_num_blocks, + k_iters_per_tile.get(), + k_iters_per_big_block, + reduction_start_block_idx, + get_sk_tiles(), + get_workspace_size(sizeof(float))); +#endif + } + + __host__ __device__ uint32_t get_sk_total_iters() const + { + uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block + + (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1); + return sk_total_iters; + } + + __host__ __device__ uint32_t get_sk_tiles() const + { + // tiles for sk + uint32_t sk_total_iters = get_sk_total_iters(); + return k_iters_per_tile.div(sk_total_iters); + } + + __host__ __device__ dim3 get_grid_dims() const + { + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1); + } + else + return dim3(reduction_start_block_idx, 1, 1); + } + + __device__ uint32_t get_block_idx() const + { + // TODO: swizzle block index for better locality + return __builtin_amdgcn_readfirstlane(blockIdx.x); + } + + __device__ void + get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const + { + if(block_idx < sk_num_big_blocks) + { + iter_start = block_idx * k_iters_per_big_block; + iter_end = iter_start + k_iters_per_big_block; + } + else if(block_idx < sk_num_blocks) + { + iter_start = (sk_num_big_blocks * k_iters_per_big_block) + + (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1); + iter_end = iter_start + (k_iters_per_big_block - 1); + } + else if(block_idx >= dp_start_block_idx) + { + uint32_t sk_total_iters = get_sk_total_iters(); + uint32_t dp_iters_per_block = k_iters_per_tile.get(); + iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block; + iter_end = iter_start + dp_iters_per_block; + } + } + + __device__ uint32_t get_current_iter_length(uint32_t iter_start, + uint32_t iter_end, + uint32_t total_iter_length) const + { + uint32_t iter_length_mod, iter_length_quo /*unused*/; + k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod); + uint32_t current_iter_length = math::min( + iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length); + return current_iter_length; + } + + __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); } + + __device__ void + get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const + { + k_iters_per_tile.divmod(iter, tile_idx, iter_offset); + } + + __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const + { + uint32_t m_tile_idx, n_tile_idx; + uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock); + n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx); + + // swizzle tile + uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock); + + uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m; + + const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem)) + ? tile_swizzle_sub_m + : tile_swizzle_sub_m_rem; + + uint32_t m_tile_idx_sub0, m_tile_idx_sub1; + m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m; + m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m; + + uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value; + + uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt; + + n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt; + m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt; + return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m, + n_tile_idx_with_adapt); + } + + __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const + { + static constexpr uint32_t alignment = 128; + uint32_t acc_buffer_bytes = + MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes; + return (acc_buffer_bytes + alignment - 1) / alignment * alignment; + } + + __host__ __device__ uint32_t get_workspace_size_for_semaphore() const + { + return get_sk_tiles() * sizeof(uint32_t); + } + + __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const + { + return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore(); + } + + __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_, + const MDiv& eqav_tiles_) const + { + uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1); + uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1; + uint32_t quo_, rem_; + eqav_tiles_.divmod(tile_idx_, quo_, rem_); + return quo_ * max_eqav_tiles_ + rem_; + } + + __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_, + uint32_t iters_per_sk_block_) const + { + return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() - + 1); + } + + __host__ __device__ uint32_t get_total_acc_buffers() const + { + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + uint32_t tiles_cover_little_blocks = + get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1); + + uint32_t total_intersec_big = + get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big); + uint32_t total_intersec_little = + get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little); + + return sk_num_blocks + total_intersec_big + total_intersec_little; + } + + __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const + { + // TODO: from big to little + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + if(tile_idx_ < tiles_cover_big_blocks) + { + uint32_t touched_sk_blocks = + (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) / + k_iters_per_big_block; + uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big); + return touched_sk_blocks + current_intersec; + } + else + { + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_; + uint32_t touched_sk_blocks = + (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) / + iters_per_little_sk_block; + uint32_t current_intersec = + get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little); + return get_total_acc_buffers() - (touched_sk_blocks + current_intersec); + } + } + + __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const + { + uint32_t iters_per_big_sk_block = k_iters_per_big_block; + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + if(block_idx_ < sk_num_big_blocks) + { + uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block + + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big); + return block_idx_ + current_intersec; + } + else + { + uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_; + uint32_t touched_tiles = k_iters_per_tile.div( + block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little); + return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec); + } + } +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index aa34cfbf84ab9425a236646400c1a965c8004f74..206ea00b9d0a94dcd5f12c02948d8f84f9223fb5 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -101,8 +101,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -346,14 +346,18 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle remove_cvref_t; using DefaultBGridDesc_BK0_N_BK1 = remove_cvref_t; - using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t; - using CountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t; - using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock = + remove_cvref_t; + using CountGridDescriptor_MBlock_MPerBlock_NBlock = + remove_cvref_t; + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2ETileMap = remove_cvref_t; @@ -518,6 +522,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ABDataType, + ABDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp index fbe89e7e5e5936c2db1470e63f922cc7f6c5ff2a..69468c25befe036123cf1796798e6a59ea47f74b 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp index bdebe3816f22d6a4f10e913210abc8c8347e20d0..bd1e0585fc96ad29ee9e5336993f2ba6570fb31d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp index 1313ec9435e782ad9cc9b92580464f6d5257b93e..fc4f27e33b15b6bfcc7ce942764f9aa5e26cfe1f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp index 6836a66047531d9eef09dbffc58188f7696a783a..203be3c42d9d23c935d10f34ceb44245f7875ad0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp index 6c5bd29f9b5ee70b995342f04d59b8ed042b1194..910c926c7e470e0a7daa1cd25f57716f5ce4af6b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -15,6 +15,7 @@ namespace ck { template (in_grid_desc_m_k, - out_grid_desc_m, - in_elementwise_op, - acc_elementwise_op, - alpha, - p_in_value_global, - p_in_index_global, - beta, - p_out_value_global, - p_out_index_global); + GridwiseReduction::template RunWithIndex( + in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + alpha, + p_in_value_global, + p_in_index_global, + beta, + p_out_value_global, + p_out_index_global); }; }; @@ -232,7 +234,7 @@ struct GridwiseReduction_mk_to_m_threadwise reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf); }; - template + template __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, const OutGridDesc_M& out_grid_desc_m, const InElementwiseOperation& in_elementwise_op, @@ -390,6 +392,18 @@ struct GridwiseReduction_mk_to_m_threadwise indexStart += KThreadSliceSize; reducedLength += KThreadSliceSize; } while(reducedLength < toReduceLength); + + if constexpr(TransformIndexKtoGlobal) + { + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + const auto coord = make_tensor_coordinate( + in_grid_desc_m_k, + make_multi_index(thread_global_1d_id * MThreadSliceSize + I, + accu_index_buf(I))); + + accu_index_buf(I) = coord.GetOffset(); + }); + } }; // for indiced operation, acc_elementwise_op shoud do nothing diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp index fccb127d0f342911ed6bb1f8396c9d9cd0eec921..9469fa7bc7b27351f393caed43a9bc12a2b8780d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -102,8 +102,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; template __host__ __device__ static constexpr auto @@ -286,8 +286,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; @@ -627,7 +628,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle Gemm1KPack, false, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index b9f4a3080a0e17ca0ba169bc8b83c006acb2305d..a0924ae3b0525b6605813a8ac58c40cd7ce4c18d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -67,6 +67,8 @@ template ; + using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; - using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t; + using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = + remove_cvref_t; - using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2E1TileMap = remove_cvref_t; @@ -710,13 +715,13 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId I1, // NBlockID - I1, // MRepeat - I1, // NRepeat - I1, // MWaveId - I1, // NWaveId - I1, // MPerXdl - I1, // NGroupNum - I1, // NInputNum + m0, // MRepeat + n0, // NRepeat + m1, // MWaveId + n1, // NWaveId + m2, // MPerXdl + n2, // NGroupNum + n3, // NInputNum n4)); // registerNum auto d0s_thread_buf = generate_tuple( @@ -732,8 +737,9 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle const auto wave_id = GetGemm0WaveIdx(); const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 - constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, n2, n4)); + static_assert(CDE0BlockTransferSrcScalarPerVector <= n4, + "vector load must be not greater than n4"); + static_assert(n4 % CDE0BlockTransferSrcScalarPerVector == 0); auto d0s_threadwise_copy = generate_tuple( [&](auto i) { @@ -742,10 +748,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle A0B0B1DataType, decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]), decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), - Sequence, + Sequence, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, - 9, - n4, + 9, // CDE0BlockTransferSrcVectorDim + CDE0BlockTransferSrcScalarPerVector, 1, false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], make_multi_index(block_work_idx[I0], // MBlockId @@ -865,7 +880,12 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle Gemm1KPack, false, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{} + Gemm1KPack * XdlopsGemm{} .K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin @@ -898,66 +918,42 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle blockwise_gemm0, acc0_thread_buf, num_k_block_main_loop); - // bias+gelu + // multiple d + if constexpr(NumD0Tensor) { - static_for<0, Gemm0MXdlPerWave, 1>{}([&](auto mr) { - static_for<0, Gemm0NXdlPerWave, 1>{}([&](auto nr) { - static_for<0, n2, 1>{}([&](auto groupid) { - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).Run( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - d0s_grid_buf[i], - d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - d0s_thread_buf(i)); - }); - - static_for<0, n4, 1>{}([&](auto i) { - constexpr index_t c_offset = acc0_thread_desc.CalculateOffset( - make_tuple(mr, nr, groupid, i)); - - // get reference to src data - const auto src_data_refs = generate_tie( - // return type should be lvalue - [&](auto iSrc) -> const auto& { - return d0s_thread_buf[iSrc][i]; - }, - Number{}); - - // get reference to dst data - auto dst_data_refs = generate_tie( - // return type should be lvalue - [&](auto) -> auto& { - return acc0_thread_buf(Number{}); - }, - Number<2>{}); - - unpack2(cde0_element_op, dst_data_refs, src_data_refs); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0)); - }); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0)); - }); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 1, -Gemm0NXdlPerWave, 0, 0, 0, 0, 0, 0)); - }); + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + d0s_grid_buf[i], + d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + d0s_thread_buf(i)); + }); + static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { return acc0_thread_buf(i); }, + Number<2>{}); + + unpack2(cde0_element_op, dst_data_refs, src_data_refs); }); static_for<0, NumD0Tensor, 1>{}([&](auto i) { d0s_threadwise_copy(i).MoveSrcSliceWindow( d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 1, -Gemm0MXdlPerWave, 0, 0, 0, 0, 0, 0, 0)); + make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); }); } + else + { + static_for<0, acc0_thread_buf.Size(), 1>{}( + [&](auto i) { cde0_element_op(acc_thread_buf(i), acc0_thread_buf[i]); }); + } // gemm1 { // TODO: explore using dynamic buffer for a1 thread buffer diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index 6a6f19d71ef523697d69719de75b70bce3953e7e..bc76d4cc4fb9e905cdb1547273866b66575ff334 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -80,7 +80,8 @@ template + int D0sTransferSrcScalarPerVector = 4, + PipelineVersion PipelineVer = PipelineVersion::v1> struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle { static_assert(LoopSched == LoopScheduler::Default, @@ -113,8 +114,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; template __host__ __device__ static constexpr auto @@ -367,12 +368,14 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle Number{}); } - using D0sGridPointer = decltype(MakeD0sGridPointer()); - using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t; + using D0sGridPointer = decltype(MakeD0sGridPointer()); + using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = + remove_cvref_t; - using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; @@ -621,13 +624,13 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId I1, // NBlockID - I1, // MRepeat - I1, // NRepeat - I1, // MWaveId - I1, // NWaveId - I1, // MPerXdl - I1, // NGroupNum - I1, // NInputNum + m0, // MRepeat + n0, // NRepeat + m1, // MWaveId + n1, // NWaveId + m2, // MPerXdl + n2, // NGroupNum + n3, // NInputNum n4)); // registerNum auto d0s_thread_buf = generate_tuple( @@ -644,9 +647,6 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle const auto wave_id = GetGemm0WaveIdx(); const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 - constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, n2, n4)); - auto d0s_threadwise_copy = generate_tuple( [&](auto i) { using D0DataType = remove_cvref_t>; @@ -655,10 +655,19 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle D0DataType, decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]), decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), - Sequence, + Sequence, Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, 9, - n4, + D0sTransferSrcScalarPerVector, 1, false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], make_multi_index(block_work_idx[I0], // MBlockId @@ -785,7 +794,8 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle Gemm1KPack, true, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin @@ -884,62 +894,35 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle // multiple d if constexpr(NumD0Tensor) { - static_for<0, MXdlPerWave, 1>{}([&](auto mr) { - static_for<0, NXdlPerWave, 1>{}([&](auto nr) { - static_for<0, n2, 1>{}([&](auto groupid) { - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).Run( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - d0s_grid_buf[i], - d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - d0s_thread_buf(i)); - }); - - static_for<0, n4, 1>{}([&](auto i) { - constexpr index_t c_offset = acc0_thread_desc.CalculateOffset( - make_tuple(mr, nr, groupid, i)); - - // get reference to src data - const auto src_data_refs = generate_tie( - // return type should be lvalue - [&](auto iSrc) -> const auto& { - return d0s_thread_buf[iSrc][i]; - }, - Number{}); - - // get reference to dst data - auto dst_data_refs = generate_tie( - // return type should be lvalue - [&](auto) -> auto& { - return acc_thread_buf(Number{}); - }, - Number<2>{}); - - unpack2(c0de_element_op, dst_data_refs, src_data_refs); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0)); - }); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0)); - }); - }); - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - d0s_threadwise_copy(i).MoveSrcSliceWindow( - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 0, 1, -NXdlPerWave, 0, 0, 0, 0, 0, 0)); - }); + static_assert(NXdlPerWave == n0); + static_assert(MXdlPerWave == m0); + + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + d0s_grid_buf[i], + d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + d0s_thread_buf(i)); + }); + static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { return acc_thread_buf(i); }, + Number<2>{}); + + unpack2(c0de_element_op, dst_data_refs, src_data_refs); }); static_for<0, NumD0Tensor, 1>{}([&](auto i) { d0s_threadwise_copy(i).MoveSrcSliceWindow( d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], - make_multi_index(0, 1, -MXdlPerWave, 0, 0, 0, 0, 0, 0, 0)); + make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); }); } else diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index d6d2051113ff4fdebb74689293b00c532aae9c68..afb2ad2e760396c200930254f871ff13e032a91a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -113,8 +113,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; template __host__ __device__ static constexpr auto @@ -300,8 +300,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; @@ -648,7 +649,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle Gemm1KPack, true, // TransposeC Gemm1KPack, // AMmaKStride - Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ // BMmaKStride make_tuple(0, 0, 0, 0)}; // A_origin diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp index ede6a96dc9f53045b5fc029f4c4f4f6aa4fa4817..ed1ffdd85765e84feb2f6824a858ab842487c18c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp index 33c45a0f037b4c81fbaf810cc10e55f71a2ab254..b6c83af13a18639cf7f8312dd416dbf4d9cb5297 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp deleted file mode 100644 index 2369f51795d9a3bcf5a28ad9de702d75e24fac48..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/grid/gridwise_contraction_dlops_v1r2.hpp +++ /dev/null @@ -1,662 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP -#define CK_GRIDWISE_CONTRACTION_DLOPS_V1R2_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_dlops_v2r3.hpp" -#include "blockwise_tensor_slice_transfer_v2.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_set.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_contraction_dlops_v1r2( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1, - const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1, - const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10) -{ - constexpr index_t shared_block_size = - GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseContraction::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_grid_desc_gk0_gm0_gm10_gm11_gk1, - b_grid_desc_gk0_gn0_gn10_gn11_gk1, - c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - c_grid_block_cluster_blockid_to_gm10_gn10, - integral_constant{}, - integral_constant{}); -} - -template -struct GridwiseContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - // GM0 and GN0 need to known at compile-time - static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0); - static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2); - static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3); - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // lds max alignment - // TODO: part of them should be moved into blockwise-gemm - // TODO: change this. I think it needs multi-dimensional alignment - constexpr auto max_lds_align = GK1; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, GM0, I1, Number{}, GK1), - max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, GN0, I1, Number{}, GK1), - max_lds_align); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_aligned_space_size = math::integer_least_multiple( - a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_aligned_space_size = math::integer_least_multiple( - b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align); - - return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); - } - - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, - const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, - const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) - { - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! GM0 and GN0 need to be known at compile-time"); - - const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2); - const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2); - const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - - return ( - (GM0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) && - GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) && - GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) && - GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) && - GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) && - GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2) && - GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) && - GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) && - GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) && - GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3)) && - (GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0)); - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) - { - const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); - const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); - - constexpr index_t GM11 = GM1PerBlockGM11; - constexpr index_t GN11 = GN1PerBlockGN11; - - const index_t GM10 = GM1 / GM11; - const index_t GN10 = GN1 / GN11; - - const index_t grid_size = GM10 * GN10; - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0) - { - const bool has_main_k_block_loop = (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1; - - return has_main_k_block_loop; - } - - __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0) - { - const bool has_double_tail_k_block_loop = (GK0 / GK0PerBlock) % 2 == 0; - - return has_double_tail_k_block_loop; - } - - __host__ __device__ static constexpr auto MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( - const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1) - { - const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); - const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2); - - const auto GM11 = Number{}; - const auto GM10 = GM1 / GM11; - - const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_tensor_descriptor( - a_grid_desc_gk0_gm0_gm1_gk1, - make_tuple(make_pass_through_transform(GK0), - make_pass_through_transform(GM0), - make_unmerge_transform(make_tuple(GM10, GM11)), - make_pass_through_transform(GK1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); - - return a_grid_desc_gk0_gm0_gm10_gm11_gk1; - } - - __host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( - const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1) - { - const auto GK0 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0); - const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2); - - const auto GN11 = Number{}; - const auto GN10 = GN1 / GN11; - - const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_tensor_descriptor( - b_grid_desc_gk0_gn0_gn1_gk1, - make_tuple(make_pass_through_transform(GK0), - make_pass_through_transform(GN0), - make_unmerge_transform(make_tuple(GN10, GN11)), - make_pass_through_transform(GK1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); - - return b_grid_desc_gk0_gn0_gn10_gn11_gk1; - } - - __host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( - const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) - { - const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); - const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); - - constexpr auto GM11 = Number{}; - constexpr auto GN11 = Number{}; - - const auto GM10 = GM1 / GM11; - const auto GN10 = GN1 / GN11; - - constexpr auto BM = GM0 * GM11; - constexpr auto BN = GN0 * GN11; - - constexpr auto BM1 = - Number{}; - constexpr auto BN1 = - Number{}; - - constexpr auto BM0 = BM / BM1; - constexpr auto BN0 = BN / BN1; - - const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_tensor_descriptor( - c_grid_desc_gm0_gm1_gn0_gn1, - make_tuple(make_pass_through_transform(GM0), - make_unmerge_transform(make_tuple(GM10, GM11)), - make_pass_through_transform(GN0), - make_unmerge_transform(make_tuple(GN10, GN11))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); - - const auto c_gm10_bm_gn10_bn_grid_desc = transform_tensor_descriptor( - c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc, - make_tuple(make_pass_through_transform(GM10), - make_merge_transform(make_tuple(GM0, GM11)), - make_pass_through_transform(GN10), - make_merge_transform(make_tuple(GN0, GN11))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_tensor_descriptor( - c_gm10_bm_gn10_bn_grid_desc, - make_tuple(make_pass_through_transform(GM10), - make_unmerge_transform(make_tuple(BM0, BM1)), - make_pass_through_transform(GN10), - make_unmerge_transform(make_tuple(BN0, BN1))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); - - return c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1; - } - - __host__ __device__ static constexpr auto MakeCGridBlockCluster_BlockId_To_GM10_GN10( - const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) - { - const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); - const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); - - constexpr auto GM11 = Number{}; - constexpr auto GN11 = Number{}; - - const auto GM10 = GM1 / GM11; - const auto GN10 = GN1 / GN11; - - const auto c_grid_block_cluster_blockid_to_gm10_gn10 = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(GM10, GN10))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); - - return c_grid_block_cluster_blockid_to_gm10_gn10; - } - - using AGridDesc_GK0_GM0_GM10_GM11_GK1 = - decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{})); - using BGridDesc_GK0_GN0_GN10_GN11_GK1 = - decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{})); - using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = - decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{})); - using CGridBlockCluster_BlockId_To_GM10_GN10 = - decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{})); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1, - const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1, - const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10, - integral_constant, - integral_constant) - { - const auto a_global_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize()); - - const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0); - - // divide block work by [GM10, GN10] - const auto c_gm10_gn10_block_cluster_idx = - c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex( - make_multi_index(get_block_1d_id())); - - // HACK: this force index data into SGPR - const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]); - const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]); - - // lds max alignment - // TODO: part of them should be moved into blockwise-gemm - // TODO: change this. I think it needs multi-dimensional alignment - constexpr auto max_lds_align = GK1; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, GM0, I1, Number{}, GK1), - max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, GN0, I1, Number{}, GK1), - max_lds_align); - - // A matrix in LDS memory for blockwise GEMM - // be careful of LDS alignment - constexpr auto a_block_desc_gk0_bm_gk1 = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, GM0 * Number{}, GK1), max_lds_align); - - // B matrix in LDS memory for blockwise GEMM - // be careful of LDS alignment - constexpr auto b_block_desc_gk0_bn_gk1 = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, GN0 * Number{}, GK1), max_lds_align); - - static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() == - a_block_desc_gk0_bm_gk1.GetElementSpaceSize() && - b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() == - b_block_desc_gk0_bn_gk1.GetElementSpaceSize(), - "wrong!"); - - // A matrix blockwise copy - auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< - BlockSize, - InMemoryDataOperationEnum::Set, - Sequence, - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1), - decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2, 3, 4>, - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // DstVectorTensorLengths - ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder - Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder - false, - true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1, - make_multi_index(0, 0, igm10, 0, 0), - a_block_desc_gk0_gm0_gm10_gm11_gk1, - make_multi_index(0, 0, 0, 0, 0)); - - // B matrix blockwise copy - auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< - BlockSize, - InMemoryDataOperationEnum::Set, - Sequence, - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1), - decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2, 3, 4>, - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // DstVectorTensorLengths - BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder - Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder - false, - true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1, - make_multi_index(0, 0, ign10, 0, 0), - b_block_desc_gk0_gn0_gn10_gn11_gk1, - make_multi_index(0, 0, 0, 0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS - // b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS - // c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in - // register - const auto blockwise_gemm = - BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< - BlockSize, - FloatAB, - FloatAB, - FloatAcc, - decltype(a_block_desc_gk0_bm_gk1), - decltype(b_block_desc_gk0_bn_gk1), - BM1PerThreadBM11, - BN1PerThreadBN11, - BK0PerThread, - BM10BN10ThreadClusterBM10Xs, - BM10BN10ThreadClusterBN10Xs, - BM1PerThreadBM11, - BN1PerThreadBN11>{}; - - constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 = - decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); - - constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_naive_tensor_descriptor_packed( - sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_aligned_space_size = math::integer_least_multiple( - a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_aligned_space_size = math::integer_least_multiple( - b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block_double = p_shared_block; - FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; - - // register allocation for output - auto c_thread_buf = make_static_buffer( - c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize()); - - ThreadwiseTensorSliceSet_v1{} - .Run(c_thread_desc_bm0_bm1_bn0_bn1, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - FloatAcc{0}); - - constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); - - auto a_block_even_buf = make_dynamic_buffer( - p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( - p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); - - auto a_block_odd_buf = make_dynamic_buffer( - p_a_block_double + a_block_aligned_space_size, - a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); - auto b_block_odd_buf = make_dynamic_buffer( - p_b_block_double + b_block_aligned_space_size, - b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); - - // LDS double buffer: preload data into LDS - { - a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); - - a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); - b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); - } - - if constexpr(HasMainKBlockLoop) - { - index_t gk0_block_on_grid = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1, - a_block_even_buf, - b_block_even_buf, - c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf); - - // odd iteration - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run( - c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); - b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); - - gk0_block_on_grid += 2 * GK0PerBlock; - } while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left - { - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); - - // LDS double buffer: load last data from device mem - a_blockwise_copy.RunRead( - a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead( - b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridStepHacks{}); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run( - c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf); - - // LDS double buffer: store last data to LDS - a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf); - - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf); - } - else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - Number{}, - I1, - Number{}, - Number{})); - - const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 = - blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( - get_thread_local_1d_id()); - - ThreadwiseTensorSliceTransfer_v1r3< - FloatAcc, - FloatC, - decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1), - decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1), - Sequence<1, - c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0], - c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1], - 1, - c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2], - c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - make_multi_index(igm10, - c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0], - c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1], - ign10, - c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2], - c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])} - .Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1, - make_tuple(I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, - c_grid_buf, - CGridStepHacks{}); - } - } -}; - -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp index 8b82b65540d1c5612002549f72708be07b649dae..d686c14b350953a48bf132694251ac677fb47c4f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp index 05257d16275e5f9817ec5d9cbbd06b133f0b5e05..bf0e8c186c75cef05b08c7ef9e1bf16c016830a3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp index b09a73590239609908265cd6b72e3e5bb26127aa..3ea72b85345869a938e52f872c2c586c6e36170f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index bebcdceb43521a3b489e0b4d04e847bcc6661b3d..9f5df8bdbf6a65f6b3a012954e7cff9d9db1ba6f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -67,7 +67,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -191,8 +191,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -346,14 +346,17 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; - using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; - using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using ReduceGridDescriptor_MBlock_MPerBlock = remove_cvref_t; @@ -501,6 +504,7 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp index 9c68b4f5c3ffbb13cd44701121fca7cc113d76cd..27f48a84ba72fe284a964ac295540eac1c4f6766 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index d46aea5e22d131a588c90fc36189c34bd5755242..1da7236978eb042c4270179dcc502ea3593021a2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -274,7 +274,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 const auto c_m0_n0_block_cluster_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - // HACK: this force index data into SGPR + // HACK: this forces index data into SGPR const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); @@ -472,7 +472,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); - // LDS doubel buffer: load next data from device mem + // LDS double buffer: load next data from device mem a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); @@ -992,7 +992,7 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, b_block_slice_copy_step); - // LDS doubel buffer: load next data from device mem + // LDS double buffer: load next data from device mem a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf); b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp deleted file mode 100644 index 84e033e1e91bb9873ebed66906c6da5a287bf1bd..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r2.hpp +++ /dev/null @@ -1,608 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP -#define CK_GRIDWISE_GEMM_DLOPS_V1R2_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_dlops_v2r2.hpp" -#include "blockwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_set.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_dlops_v1r2( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AKM0M1GridDesc a_k_m0_m1_grid_desc, - const BKN0N1GridDesc b_k_n0_n1_grid_desc, - const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k_m0_m1_grid_desc, - b_k_n0_n1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - cblockid_to_m0_n0_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -} - -template -struct GridwiseGemmDlops_km_kn_mn_v1r2 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = math::lcm(Number{}, - Number{}, - Number{}, - Number{}); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}), max_lds_align); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_aligned_space_size = - math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_aligned_space_size = - math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); - - return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); - } - - __host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc, - const BKNGridDesc& b_k_n_grid_desc, - const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = a_k_m_grid_desc.GetLength(I1); - const auto N = b_k_n_grid_desc.GetLength(I1); - const auto K = a_k_m_grid_desc.GetLength(I0); - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - - return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K == b_k_n_grid_desc.GetLength(I0)) && - (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K % KPerBlock == 0); - } - - __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) - { - const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) - { - const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1; - - return has_main_k_block_loop; - } - - __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K) - { - const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0; - - return has_double_tail_k_block_loop; - } - - __host__ __device__ static constexpr auto - MakeAKM0M1GridDescriptor(const AKMGridDesc& a_k_m_grid_desc) - { - const auto K = a_k_m_grid_desc.GetLength(I0); - const auto M = a_k_m_grid_desc.GetLength(I1); - - const auto M1 = Number{}; - const auto M0 = M / M1; - - const auto a_k_m0_m1_grid_desc = transform_tensor_descriptor( - a_k_m_grid_desc, - make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{})); - - return a_k_m0_m1_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc) - { - const auto K = b_k_n_grid_desc.GetLength(I0); - const auto N = b_k_n_grid_desc.GetLength(I1); - - const auto N1 = Number{}; - const auto N0 = N / N1; - - const auto b_k_n0_n1_grid_desc = transform_tensor_descriptor( - b_k_n_grid_desc, - make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{})); - - return b_k_n0_n1_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - constexpr auto M11 = - Number{}; - constexpr auto N11 = - Number{}; - - constexpr auto M10 = M1 / M11; - constexpr auto N10 = N1 / N11; - - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), - make_unmerge_transform(make_tuple(N0, N10, N11))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - - return c_m0_m10_m11_n0_n10_n11_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto cblockid_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); - - return cblockid_to_m0_n0_block_cluster_adaptor; - } - - using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{})); - using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{})); - using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, - const AKM0M1GridDesc& a_k_m0_m1_grid_desc, - const BKN0N1GridDesc& b_k_n0_n1_grid_desc, - const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor, - integral_constant, - integral_constant) - { - const auto a_global_buf = make_dynamic_buffer( - p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); - - const auto K = a_k_m0_m1_grid_desc.GetLength(I0); - - // divide block work by [M, N] - const auto c_m0_n0_block_cluster_idx = - cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( - make_multi_index(get_block_1d_id())); - - // HACK: this force index data into SGPR - const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); - const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(Number{}, - Number{}, - Number{}, - Number{}); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}), max_lds_align); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k_m0_m1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, I1, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k_n0_n1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, I1, Number{}), max_lds_align); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K_M0_M1, - ABlockTransferThreadClusterLengths_K_M0_M1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_k_m0_m1_grid_desc), - decltype(a_k_m0_m1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_M1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>(a_k_m0_m1_grid_desc, - make_multi_index(0, im0, 0), - a_k_m0_m1_block_desc, - make_multi_index(0, 0, 0)); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K_N0_N1, - BBlockTransferThreadClusterLengths_K_N0_N1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_k_n0_n1_grid_desc), - decltype(b_k_n0_n1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_N1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_k_n0_n1_grid_desc, - make_multi_index(0, in0, 0), - b_k_n0_n1_block_desc, - make_multi_index(0, 0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlockM1] is in LDS - // b_mtx[KPerBlocl, NPerBlockN1] is in LDS - // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in - // register - const auto blockwise_gemm = - BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2{}; - constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = - decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); - - constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed( - sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_aligned_space_size = - math::integer_least_multiple(a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_aligned_space_size = - math::integer_least_multiple(b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block_double = p_shared_block; - FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; - - // register allocation for output - auto c_thread_buf = make_static_buffer( - c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); - - ThreadwiseTensorSliceSet_v1{} - .Run(c_m10_m11_n10_n11_thread_desc, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - FloatAcc{0}); - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); - - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k_m0_m1_global_step_hacks = AGridStepHacks{}; - constexpr auto b_k_n0_n1_global_step_hacks = BGridStepHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_k_m0_m1_global_move_slice_window_step_hack = - AGridMoveSliceWindowStepHacks{}; - constexpr auto b_k_n0_n1_global_move_slice_window_step_hack = - BGridMoveSliceWindowStepHacks{}; - - auto a_block_even_buf = make_dynamic_buffer( - p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( - p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize()); - - auto a_block_odd_buf = make_dynamic_buffer( - p_a_block_double + a_block_aligned_space_size, - a_k_m0_m1_block_desc.GetElementSpaceSize()); - auto b_block_odd_buf = make_dynamic_buffer( - p_b_block_double + b_block_aligned_space_size, - b_k_n0_n1_block_desc.GetElementSpaceSize()); - - // LDS double buffer: preload data into LDS - { - a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); - b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); - - a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); - } - - if constexpr(HasMainKBlockLoop) - { - index_t k_block_data_begin = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, - a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, - b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_step_hack); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); - b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, - a_block_even_buf, - b_block_even_buf, - c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); - - // odd iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, - a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, - b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_step_hack); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); - b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); - - k_block_data_begin += 2 * KPerBlock; - } while(k_block_data_begin < K - 2 * KPerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left - { - a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, - a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, - b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_step_hack); - - __syncthreads(); - - // LDS double buffer: load last data from device mem - a_blockwise_copy.RunRead( - a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_step_hacks); - b_blockwise_copy.RunRead( - b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_step_hacks); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); - - // LDS double buffer: store last data to LDS - a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); - - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); - } - else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - Number{}, - I1, - Number{}, - Number{})); - - const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = - blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); - - ThreadwiseTensorSliceTransfer_v1r3< - FloatAcc, - FloatC, - decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), - decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), - Sequence<1, - c_m10_m11_n10_n11_thread_tensor_lengths[I0], - c_m10_m11_n10_n11_thread_tensor_lengths[I1], - 1, - c_m10_m11_n10_n11_thread_tensor_lengths[I2], - c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{c_m0_m10_m11_n0_n10_n11_grid_desc, - make_multi_index(im0, - c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], - c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], - in0, - c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], - c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} - .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_grid_buf, - CGridStepHacks{}); - } - } -}; - -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp deleted file mode 100644 index b1dfb0c73fc92a19662e01df31090008cdef87ff..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v2.hpp +++ /dev/null @@ -1,461 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_GRIDWISE_GEMM_V2_HPP -#define CK_GRIDWISE_GEMM_V2_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "blockwise_gemm_dlops_v3.hpp" - -namespace ck { - -template -struct GridwiseGemmDlops_km_kn_mn_v3 -{ - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto E = EPerBlock * 3 * 3; - - constexpr auto max_lds_align = - math::lcm(Number{}, Number{}); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}), max_lds_align); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align); - - return a_block_space_size * sizeof(FloatAB); - } - - template - __device__ void Run(const AGlobalDesc& a_e_k_global_desc, - const FloatAB* __restrict__ p_a_global, - const BGlobalDesc& b_e_n_ho_wo_global_desc, - const FloatAB* __restrict__ p_b_global, - const CGlobalDesc& c_k_n_ho_wo_global_desc, - FloatC* __restrict__ p_c_global, - FloatAB* __restrict__ p_shared_block, - integral_constant, - integral_constant) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto a_global_buf = make_dynamic_buffer( - p_a_global, a_e_k_global_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize()); - auto c_global_buf = make_dynamic_buffer( - p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize()); - - constexpr auto E = EPerBlock * 3 * 3; - - // const auto E = a_e_k_global_desc.GetLength(I0); - const auto K = a_e_k_global_desc.GetLength(I1); - - const auto N = b_e_n_ho_wo_global_desc.GetLength(I1); - const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2); - const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3); - -// divide block work by [M, N] -#if 0 - const auto ho_block_work_num = Ho / Number{}; - const auto wo_block_work_num = Wo / Number{}; - const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num; - - const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num; - const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num; - - const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num; - const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; -#else - // Hack: this force result into SGPR - const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock); - const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock); - const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num; - - const index_t k_block_work_id = - __builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num); - const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num; - - const index_t ho_block_work_id = - __builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num); - const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; -#endif - - // lds max alignment - constexpr auto max_lds_align = - math::lcm(Number{}, Number{}); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_e_k_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}), max_lds_align); - - constexpr auto a_e_k_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_e_n_ho_wo_block_desc = make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); - - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_k_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); - - auto blockwise_gemm = - BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; - - auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const auto k_thread_id = c_thread_mtx_index.k; - const auto ho_thread_id = c_thread_mtx_index.h; - const auto wo_thread_id = c_thread_mtx_index.w; - - const index_t k_block_data_on_global = k_block_work_id * KPerBlock; - const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock; - const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock; - - const index_t ho_thread_data_on_global = - ho_block_data_on_global + ho_thread_id * HoPerThread; - const index_t wo_thread_data_on_global = - wo_block_data_on_global + wo_thread_id * WoPerThread; - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_E_K, - ABlockTransferThreadClusterLengths_E_K, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_e_k_global_desc), - decltype(a_e_k_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 1>, - ABlockTransferSrcVectorDim, - 1, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>(a_e_k_global_desc, - make_multi_index(0, k_block_data_on_global), - a_e_k_desc, - make_multi_index(0, 0)); - - constexpr auto b_e_n_ho_wo_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( - Number{}, Number<1>{}, Number{}, Number{})); - - auto b_threadwise_transfer = - ThreadwiseTensorSliceTransfer_v2, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - 1, - true>( - b_e_n_ho_wo_global_desc, - make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); - - auto a_block_buf = make_dynamic_buffer( - p_shared_block, a_e_k_desc.GetElementSpaceSize()); - - // register allocation for output - StaticBuffer - c_thread_buf; - - // initialize output thread tensor - ThreadwiseTensorSliceSet_v1>{} - .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); - - constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); - - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_e_k_global_step_hacks = AGlobalStepHacks{}; - constexpr auto b_e_n_ho_wo_global_step_hacks = BGlobalStepHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_e_k_global_move_slice_window_step_hack = AGlobalMoveSliceWindowStepHacks{}; - constexpr auto b_e_n_ho_wo_global_move_slice_window_step_hack = - BGlobalMoveSliceWindowStepHacks{}; - - // double regsiter buffer for b - StaticBuffer - b_thread_even_buf, b_thread_odd_buf; - - // LDS double buffer: preload data - { - a_blockwise_copy.RunRead(a_e_k_global_desc, a_global_buf, a_e_k_global_step_hacks); - - b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, - b_global_buf, - b_e_n_ho_wo_thread_desc, - make_tuple(I0, I0, I0, I0), - b_thread_even_buf, - b_e_n_ho_wo_global_step_hacks); - - a_blockwise_copy.RunWrite(a_e_k_desc, a_block_buf); - } - - __syncthreads(); - - if constexpr(HasMainKBlockLoop) - { - index_t e_block_data_begin = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, - b_thread_slice_copy_step); - - b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, - b_global_buf, - b_e_n_ho_wo_thread_desc, - make_tuple(I0, I0, I0, I0), - b_thread_odd_buf, - b_e_n_ho_wo_global_step_hacks); - - // LDS double buffer: GEMM on current data - // TODO: @Zhang Jing: blockwise gemm should be able to move slice window - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - - blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); - - b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, - b_thread_slice_copy_step); - - b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, - b_global_buf, - b_e_n_ho_wo_thread_desc, - make_tuple(I0, I0, I0, I0), - b_thread_even_buf, - b_e_n_ho_wo_global_step_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); - - blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); - - e_block_data_begin += 2 * EPerBlock; - - } while(e_block_data_begin < E - 2 * EPerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left - { - b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, - b_thread_slice_copy_step); - - b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, - b_global_buf, - b_e_n_ho_wo_thread_desc, - make_tuple(I0, I0, I0, I0), - b_thread_odd_buf, - b_e_n_ho_wo_global_step_hacks); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - - blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0)); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); - } - else // if has 1 iteration left - { - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - } - - // output: register to global memory - { - // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor - constexpr auto c_k_n_ho_wo_global_tensor_step_hacks = CGlobalStepHacks{}; - - const index_t k_thread_data_on_global = - k_block_data_on_global + k_thread_id * KPerThread; - - ThreadwiseTensorSliceTransfer_v1r3, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>( - c_k_n_ho_wo_global_desc, - make_multi_index( - k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global)) - .Run(c_k_n_ho_wo_thread_desc, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - c_k_n_ho_wo_global_desc, - c_global_buf, - c_k_n_ho_wo_global_tensor_step_hacks); - } - } - - // pass tensor descriptor by reference - template - __device__ void Run(const AGlobalDesc& a_e_k_global_desc, - const FloatAB* __restrict__ p_a_global, - const BGlobalDesc& b_e_n_ho_wo_global_desc, - const FloatAB* __restrict__ p_b_global, - const CGlobalDesc& c_k_n_ho_wo_global_desc, - FloatC* __restrict__ p_c_global, - integral_constant, - integral_constant) const - { - constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - Run(a_e_k_global_desc, - p_a_global, - b_e_n_ho_wo_global_desc, - p_b_global, - c_k_n_ho_wo_global_desc, - p_c_global, - p_shared_block, - integral_constant{}, - integral_constant{}); - } - - // pass tensor descriptors by their pointers - template - __device__ void Run(const AGlobalDesc* p_a_e_k_global_desc, - const FloatAB* __restrict__ p_a_global, - const BGlobalDesc* p_b_e_n_ho_wo_global_desc, - const FloatAB* __restrict__ p_b_global, - const CGlobalDesc* p_c_k_n_ho_wo_global_desc, - FloatC* __restrict__ p_c_global, - integral_constant, - integral_constant) const - { - const auto a_e_k_global_desc = *p_a_e_k_global_desc; - const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc; - const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc; - - Run(a_e_k_global_desc, - p_a_global, - b_e_n_ho_wo_global_desc, - p_b_global, - c_k_n_ho_wo_global_desc, - p_c_global, - integral_constant{}, - integral_constant{}); - } - - // pass tensor descriptors by void* - template - __device__ void Run(const void* p_a_e_k_global_desc, - const FloatAB* __restrict__ p_a_global, - const void* p_b_e_n_ho_wo_global_desc, - const FloatAB* __restrict__ p_b_global, - const void* p_c_k_n_ho_wo_global_desc, - FloatC* __restrict__ p_c_global, - integral_constant, - integral_constant) const - { - const auto a_e_k_global_desc = *reinterpret_cast(p_a_e_k_global_desc); - const auto b_e_n_ho_wo_global_desc = - *reinterpret_cast(p_b_e_n_ho_wo_global_desc); - const auto c_k_n_ho_wo_global_desc = - *reinterpret_cast(p_c_k_n_ho_wo_global_desc); - - Run(a_e_k_global_desc, - p_a_global, - b_e_n_ho_wo_global_desc, - p_b_global, - c_k_n_ho_wo_global_desc, - p_c_global, - integral_constant{}, - integral_constant{}); - } -}; - -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp deleted file mode 100644 index ace844338414668030e5dd6aeeaf559e1a7c1060..0000000000000000000000000000000000000000 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v3.hpp +++ /dev/null @@ -1,1597 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_GRIDWISE_GEMM_V3_HPP -#define CK_GRIDWISE_GEMM_V3_HPP - -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "threadwise_tensor_slice_set.hpp" -#include "blockwise_gemm_dlops_v3.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_dlops_v3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - const FloatC* __restrict__ p_bias_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::ConvBiasActiv(p_a_grid, - p_b_grid, - p_bias_grid, - p_c_grid, - p_shared_block, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - cblockid_to_k_n_h_w_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -} - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_dlops_v3_resize_add( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - const FloatC* __restrict__ p_bias_grid, - FloatC* __restrict__ p_d_grid, - const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::ConvBiasActivResizeAdd(p_a_grid, - p_b_grid, - p_bias_grid, - p_d_grid, - p_shared_block, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - cblockid_to_k_n_h_w_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -} - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_gemm_dlops_v3_maxpool( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - const FloatC* __restrict__ p_bias_grid, - FloatC* __restrict__ p_c_grid, - FloatC* __restrict__ p_d_grid, - const AGridDesc_E0_E1_K0_K1_E2 a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - const CBlockIdToBlockClusterAdaptor_K_N_H_W cblockid_to_k_n_h_w_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::ConvBiasActivMaxpool(p_a_grid, - p_b_grid, - p_bias_grid, - p_c_grid, - p_d_grid, - p_shared_block, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - cblockid_to_k_n_h_w_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -} - -template -struct GridwiseGemmDlops_km_kn_mn_v3 -{ - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - - static constexpr auto E1 = Number{}; - static constexpr auto E2 = Number{}; - static constexpr auto K2 = Number{}; - - static constexpr auto NPerBlock = I1; - - static constexpr FloatAcc alpha = 0.3; - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = Number{}; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_e0_e1_k1_e2_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(I1, Number{}, Number{}, Number{}), max_lds_align); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = math::integer_least_multiple( - a_e0_e1_k1_e2_block_desc.GetElementSpaceSize(), max_lds_align); - - return a_block_space_size * sizeof(FloatAB); - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) - { - const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); - const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); - const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); - const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); - - const auto K0 = K / KPerBlock; - const auto N0 = N / NPerBlock; - const auto H0 = Ho / HoPerBlock; - const auto W0 = Wo / WoPerBlock; - - const index_t grid_size = K0 * N0 * H0 * W0; - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainE0BlockLoop(const index_t E0) - { - const bool has_main_e0_block_loop = E0 > 1; - - return has_main_e0_block_loop; - } - - __host__ __device__ static constexpr bool CalculateHasMainE1BlockLoop() - { - const bool has_main_e1_block_loop = ((E1 + E1PerBlock) / (2 * E1PerBlock)) > 1; - - return has_main_e1_block_loop; - } - - __host__ __device__ static constexpr bool CalculateHasDoubleTailE1BlockLoop() - { - const bool has_double_tail_e1_block_loop = (E1 / E1PerBlock) % 2 == 0; - - return has_double_tail_e1_block_loop; - } - - __host__ __device__ static constexpr auto - MakeAE0E1K0K1E2GridDescriptor(const AGridDesc_E0_E1_K_E2& a_e0_e1_k_e2_grid_desc) - { - const auto E0 = a_e0_e1_k_e2_grid_desc.GetLength(I0); - const auto K = a_e0_e1_k_e2_grid_desc.GetLength(I2); - - const auto K1 = Number{}; - const auto K0 = K / K1; - - const auto a_e0_e1_k0_k1_e2_grid_desc = transform_tensor_descriptor( - a_e0_e1_k_e2_grid_desc, - make_tuple(make_pass_through_transform(E0), - make_pass_through_transform(E1), - make_unmerge_transform(make_tuple(K0, K1)), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); - - return a_e0_e1_k0_k1_e2_grid_desc; - } - - __host__ __device__ static constexpr auto MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor( - const BGridDesc_E0_E1_N_Ho_Wo_E2& b_e0_e1_n_ho_wo_e2_grid_desc) - { - const auto E0 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I0); - // const auto E1 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I1); - const auto N = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I2); - const auto Ho = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I3); - const auto Wo = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I4); - // const auto E2 = b_e0_e1_n_ho_wo_e2_grid_desc.GetLength(I5); - - const auto H2 = Number{}; - const auto H1 = Number{}; - const auto H0 = Ho / (H1 * H2); - - const auto W2 = Number{}; - const auto W1 = Number{}; - const auto W0 = Wo / (W1 * W2); - - const auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc = - transform_tensor_descriptor(b_e0_e1_n_ho_wo_e2_grid_desc, - make_tuple(make_pass_through_transform(E0), - make_pass_through_transform(E1), - make_pass_through_transform(N), - make_unmerge_transform(make_tuple(H0, H1, H2)), - make_unmerge_transform(make_tuple(W0, W1, W2)), - make_pass_through_transform(E2)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3, 4, 5>{}, - Sequence<6, 7, 8>{}, - Sequence<9>{})); - - return b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeCK0K1NH0H1H2W0W1W2GridDescriptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) - { - const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); - const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); - const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); - const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); - - const auto K1 = Number{}; - const auto K0 = K / K1; - - const auto H2 = Number{}; - const auto H1 = Number{}; - const auto H0 = Ho / (H1 * H2); - - const auto W2 = Number{}; - const auto W1 = Number{}; - const auto W0 = Wo / (W1 * W2); - - const auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc = transform_tensor_descriptor( - c_k_n_ho_wo_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1)), - make_pass_through_transform(N), - make_unmerge_transform(make_tuple(H0, H1, H2)), - make_unmerge_transform(make_tuple(W0, W1, W2))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); - - return c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeDK0K1NH0H1HxW0W1WxGridDescriptorMaxPool(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc) - { - const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0); - const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1); - const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2); - const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3); - - const auto K1 = Number{}; - const auto K0 = K / K1; - -#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR - const auto H2 = Number{}; - const auto H1 = Number{}; - const auto H0 = Number{}; - - const auto W2 = Number{}; - const auto W1 = Number{}; - const auto W0 = Number{}; -#else - const auto H2 = HoPerThread / 2; - const auto H1 = HoPerBlock / HoPerThread; - const auto H0 = Hx / (H1 * H2); - - const auto W2 = WoPerThread / 2; - const auto W1 = WoPerBlock / WoPerThread; - const auto W0 = Wx / (W1 * W2); -#endif - - const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor( - d_k_n_hx_wx_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1)), - make_pass_through_transform(N), - make_unmerge_transform(make_tuple(H0, H1, H2)), - make_unmerge_transform(make_tuple(W0, W1, W2))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); - - return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeDK0K1NH0H1HxW0W1WxGridDescriptorResizeAdd(const DGridDesc_K_N_Hx_Wx& d_k_n_hx_wx_grid_desc) - { - const auto K = d_k_n_hx_wx_grid_desc.GetLength(I0); - const auto N = d_k_n_hx_wx_grid_desc.GetLength(I1); - const auto Hx = d_k_n_hx_wx_grid_desc.GetLength(I2); - const auto Wx = d_k_n_hx_wx_grid_desc.GetLength(I3); - - const auto K1 = Number{}; - const auto K0 = K / K1; - - const auto H2 = Number{}; - const auto H1 = Number{}; - - const auto W2 = Number{}; - const auto W1 = Number{}; - -#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR - const auto H0 = Number{}; - const auto W0 = Number{}; -#else - const auto H0 = Hx / (H1 * H2); - const auto W0 = Wx / (W1 * W2); -#endif - - const auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc = transform_tensor_descriptor( - d_k_n_hx_wx_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1)), - make_pass_through_transform(N), - make_unmerge_transform(make_tuple(H0, H1, H2)), - make_unmerge_transform(make_tuple(W0, W1, W2))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}, Sequence<6, 7, 8>{})); - - return d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeCBlockIdToKNHoWoBlockClusterAdaptor(const CGridDesc_K_N_Ho_Wo& c_k_n_ho_wo_grid_desc) - { - const auto K = c_k_n_ho_wo_grid_desc.GetLength(I0); - const auto N = c_k_n_ho_wo_grid_desc.GetLength(I1); - const auto Ho = c_k_n_ho_wo_grid_desc.GetLength(I2); - const auto Wo = c_k_n_ho_wo_grid_desc.GetLength(I3); - -#if CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR - const auto K0 = Number{}; - const auto N0 = Number{}; - const auto H0 = Number{}; - const auto W0 = Number{}; -#else - const auto K0 = K / KPerBlock; - const auto N0 = N / NPerBlock; - const auto H0 = Ho / HoPerBlock; - const auto W0 = Wo / WoPerBlock; -#endif - - const auto cblockid_to_k_n_ho_wo_block_cluster_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - return cblockid_to_k_n_ho_wo_block_cluster_adaptor; - } - - // using AGridDesc_E0_E1_K0_K1_E2 = - // decltype(MakeAE0E1K0K1E2GridDescriptor(AGridDesc_E0_E1_K_E2{})); - // using BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2 = - // decltype(MakeBE0E1NH0H1H2W0W1W2E2GridDescriptor(BGridDesc_E0_E1_N_Ho_Wo_E2{})); - // using CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2 = - // decltype(MakeCK0K1NH0H1H2W0W1W2GridDescriptor(CGridDesc_K_N_Ho_Wo{})); - // using DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx = - // decltype(MakeDK0K1NH0H1HxW0W1WxGridDescriptor(DGridDesc_K_N_Hx_Wx{})); - - using CBlockIdToBlockClusterAdaptor_K_N_H_W = - decltype(MakeCBlockIdToKNHoWoBlockClusterAdaptor(CGridDesc_K_N_Ho_Wo{})); - - template - __host__ __device__ static constexpr auto MakeBiasK0K1GridDescriptor( - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc) - { - const auto K0 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I0); - const auto K1 = c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetLength(I1); - - return make_naive_tensor_descriptor_packed(make_tuple(K0, K1)); - } - - __host__ __device__ static constexpr auto MakeCK1NH2W2ThreadDescriptor() - { - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, I1, Number{}, Number{})); - return c_k1_n_h2_w2_thread_gemm_desc; - } - - // using CThreadDesc_K1_N_H2_W2 = decltype(MakeCK1NH2W2ThreadDescriptor()); - - __host__ __device__ static constexpr auto GetBlockWiseGemm() - { - constexpr auto max_lds_align = Number{}; - - constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, Number{}), max_lds_align); - - constexpr auto b_e1_n_h_w_e2_block_gemm_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - I1, - Number{}, - Number{}, - Number{})); - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); - - auto blockwise_gemm = - BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; - - return blockwise_gemm; - } - - __device__ static constexpr auto GetCThreadIndex() - { - auto blockwise_gemm = GetBlockWiseGemm(); - auto c_thread_mtx_index = - blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id()); - - return c_thread_mtx_index; - }; - - __device__ static constexpr auto GetCBlockIndex( - const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor) - { - const auto c_k_n_h_w_block_cluster_idx = - cblockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex( - make_multi_index(get_block_1d_id())); - return c_k_n_h_w_block_cluster_idx; - } - - template - __device__ static void BiasOp(BiasGlobalBuff& bias_global_buf, - CThreadBuff& c_thread_buf, - const CBlockIndex& c_block_idx, - const CThreadIndex& c_thread_idx, - const BiasGridDesc_K0_K1& bias_k0_k1_grid_desc, - const CThreadDesc_K1_N_H2_W2&) - - { - const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); - - const auto k_thread_id = c_thread_idx[I0]; - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; - - constexpr auto bias_k0_k1_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); - - StaticBuffer - bias_thread_buf; - - const index_t k_thread_data_on_global = k_thread_id * KPerThread; - - auto bias_threadwise_transfer = - ThreadwiseTensorSliceTransfer_v2{}>, - Sequence<0, 1>, - 1, - CThreadTransferDstScalarPerVector, - false, - true>( - bias_k0_k1_grid_desc, make_multi_index(k_block_work_id, k_thread_data_on_global)); - - constexpr auto bias_k0_k1_global_tensor_step_hacks = make_tuple( - make_tuple(Sequence<0>{}, Sequence<0>{}), make_tuple(Sequence<0>{}, Sequence<0>{})); - - bias_threadwise_transfer.Run(bias_k0_k1_grid_desc, - bias_global_buf, - bias_k0_k1_thread_desc, - make_tuple(I0, I0), - bias_thread_buf, - bias_k0_k1_global_tensor_step_hacks); - - static_for<0, KPerThread, 1>{}([&](auto ki) { - static_for<0, HoPerThread, 1>{}([&](auto hi) { - static_for<0, WoPerThread, 1>{}([&](auto wi) { - constexpr index_t c_offset = - c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset(make_tuple(ki, 0, hi, wi)); - c_thread_buf(Number{}) = - c_thread_buf[Number{}] + bias_thread_buf[ki]; - }); - }); - }); - } - - template - __device__ static void Activation(CThreadBuff& c_thread_buf, - const CThreadDesc_K1_N_H2_W2&, - integral_constant) - { - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; - - static_for<0, c_k1_n_h2_w2_thread_gemm_desc.GetElementSpaceSize(), 1>{}([&](auto i) { - if constexpr(activ_type_ == 1) - { - c_thread_buf(i) = c_thread_buf[i] >= 0 ? c_thread_buf[i] : alpha * c_thread_buf[i]; - } - else if constexpr(activ_type_ == 2) - { - FloatAcc x = 1.0 + exp(-c_thread_buf[i]); - - asm volatile("\n \ - v_rcp_f32 %0, %1 \n" - : "=v"(x) - : "0"(x)); - - c_thread_buf(i) = x; - } - }); - } - - template - __device__ static void - WriteOut(const CThreadBuff& c_thread_buf, - CGlobalBuff& c_global_buf, - const CBlockIndex& c_block_idx, - const CThreadIndex& c_thread_idx, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc) - { - const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); - const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); - const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); - const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); - - const auto k_thread_id = c_thread_idx[I0]; - const auto ho_thread_id = c_thread_idx[I2]; - const auto wo_thread_id = c_thread_idx[I3]; - - // hack to control index calculation when iterating over c_k_n_h0_h1_h2_w0_w1_w2_global - // tensor - constexpr auto c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks = CGlobalStepHacks{}; - - constexpr auto c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc = - make_naive_tensor_descriptor_packed(make_tuple(I1, - Number{}, - I1, - I1, - I1, - Number{}, - I1, - I1, - Number{})); - - const index_t k_thread_data_on_global = k_thread_id * KPerThread; - - ThreadwiseTensorSliceTransfer_v1r3< - FloatAcc, - FloatC, - decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc), - decltype(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - make_multi_index(k_block_work_id, - k_thread_data_on_global, - n_block_work_id, - ho_block_work_id, - ho_thread_id, - 0, - wo_block_work_id, - wo_thread_id, - 0)) - .Run(c_k0_k1_n_h0_h1_h2_w0_w1_w2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - c_global_buf, - c_k_n_h0_h1_h2_w0_w1_w2_global_tensor_step_hacks); - } - - template - __device__ static void - MaxPool(const CThreadBuff& c_thread_buf, - DGlobalBuff& d_global_buf, - const CBlockIndex& c_block_idx, - const CThreadIndex& c_thread_idx, - const CThreadDesc_K1_N_H2_W2&, - const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc) - { - - const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); - const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); - const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); - const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); - - const auto k_thread_id = c_thread_idx[I0]; - const auto ho_thread_id = c_thread_idx[I2]; - const auto wo_thread_id = c_thread_idx[I3]; - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; - - static_assert(HoPerThread % 2 == 0 && WoPerThread % 2 == 0, ""); - - constexpr auto HoPerThread_2 = HoPerThread / 2; - constexpr auto WoPerThread_2 = WoPerThread / 2; - - constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple(I1, - Number{}, - I1, - I1, - I1, - Number{}, - I1, - I1, - Number{})); - - StaticBuffer - d_thread_buf; - - static_for<0, KPerThread, 1>{}([&](auto ki) { - static_for<0, HoPerThread_2, 1>{}([&](auto hi) { - static_for<0, WoPerThread_2, 1>{}([&](auto wi) { - constexpr index_t d_offset = - d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc.CalculateOffset( - make_tuple(0, ki, 0, 0, 0, hi, 0, 0, wi)); - - constexpr index_t c_offset_0 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( - make_tuple(ki, 0, hi * 2, wi * 2)); - constexpr index_t c_offset_1 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( - make_tuple(ki, 0, hi * 2, wi * 2 + 1)); - constexpr index_t c_offset_2 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( - make_tuple(ki, 0, hi * 2 + 1, wi * 2)); - constexpr index_t c_offset_3 = c_k1_n_h2_w2_thread_gemm_desc.CalculateOffset( - make_tuple(ki, 0, hi * 2 + 1, wi * 2 + 1)); - - d_thread_buf(Number{}) = c_thread_buf[Number{}]; - d_thread_buf(Number{}) = - fmaxf(c_thread_buf[Number{}], d_thread_buf(Number{})); - d_thread_buf(Number{}) = - fmaxf(c_thread_buf[Number{}], d_thread_buf(Number{})); - d_thread_buf(Number{}) = - fmax(c_thread_buf[Number{}], d_thread_buf(Number{})); - }); - }); - }); - - const index_t k_thread_data_on_global = k_thread_id * KPerThread; - - constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{}; - - ThreadwiseTensorSliceTransfer_v1r3< - FloatC, - FloatC, - decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc), - decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - InMemoryDataOperationEnum::Set, - 1, - true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - make_multi_index(k_block_work_id, - k_thread_data_on_global, - n_block_work_id, - ho_block_work_id, - ho_thread_id, - 0, - wo_block_work_id, - wo_thread_id, - 0)) - .Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), - d_thread_buf, - d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - d_global_buf, - d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks); - } - - template - __device__ static void - ResizeAdd(const CThreadBuff& c_thread_buf, - DGlobalBuff& d_global_buf, - const CBlockIndex& c_block_idx, - const CThreadIndex& c_thread_idx, - const CThreadDesc_K1_N_H2_W2&, - const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc) - { - - const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); - const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); - const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); - const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); - - const auto k_thread_id = c_thread_idx[I0]; - const auto ho_thread_id = c_thread_idx[I2]; - const auto wo_thread_id = c_thread_idx[I3]; - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; - - constexpr auto HoPerThreadx2 = HoPerThread * 2; - constexpr auto WoPerThreadx2 = WoPerThread * 2; - - constexpr auto d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple(I1, - Number{}, - I1, - I1, - I1, - Number{}, - I1, - I1, - Number{})); - - StaticBuffer - d_thread_buf; - - static_for<0, KPerThread, 1>{}([&](auto k_i) { - static_for<0, HoPerThreadx2, 1>{}([&](auto h_i) { - static_for<0, WoPerThreadx2, 1>{}([&](auto w_i) { - d_thread_buf(Number{}) = - c_thread_buf[Number{}]; - }); - }); - }); - - // hack to control index calculation when iterating over d_k_n_ho_wo_global tensor - constexpr auto d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks = DGlobalStepHacks{}; - - const index_t k_thread_data_on_global = k_thread_id * KPerThread; - - ThreadwiseTensorSliceTransfer_v1r3< - FloatC, - FloatC, - decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc), - decltype(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - InMemoryDataOperationEnum::Add, - 1, - true>(d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - make_multi_index(k_block_work_id, - k_thread_data_on_global, - n_block_work_id, - ho_block_work_id, - ho_thread_id, - 0, - wo_block_work_id, - wo_thread_id, - 0)) - .Run(d_k0_k1_n_h0_h1_hx_w0_w1_wx_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0), - d_thread_buf, - d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - d_global_buf, - d_k_n_h0_h1_hx_w0_w1_wx_global_tensor_step_hacks); - } - - template - __device__ static void - GemmOp(const AGlobalBuff& a_global_buf, - const BGlobalBuff& b_global_buf, - CThreadBuff& c_thread_buf, - FloatAB* __restrict__ p_shared_block, - const CBlockIndex& c_block_idx, - const CThreadIndex& c_thread_idx, - const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CThreadDesc_K1_N_H2_W2&, - integral_constant) - { - constexpr auto HasMainE1BlockLoop = CalculateHasMainE1BlockLoop(); - constexpr auto HasDoubleTailE1BlockLoop = CalculateHasDoubleTailE1BlockLoop(); - - // const auto c_k_n_h_w_block_cluster_idx = - // GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); - // cblockid_to_k_n_h_w_block_cluster_adaptor.CalculateBottomIndex( - // make_multi_index(get_block_1d_id())); - - const index_t k_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I0]); - const index_t n_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I1]); - const index_t ho_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I2]); - const index_t wo_block_work_id = __builtin_amdgcn_readfirstlane(c_block_idx[I3]); - - constexpr auto max_lds_align = Number{}; - - constexpr auto a_e1_k1_e2_block_gemm_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, Number{}), max_lds_align); - - constexpr auto b_e1_n_h_w_e2_block_gemm_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - I1, - Number{}, - Number{}, - Number{})); - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = CThreadDesc_K1_N_H2_W2{}; - - auto blockwise_gemm = - BlockwiseGemmDlops_km_kn_m0m1n0n1_v3{}; - // blockwise_gemm.GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id()); - - const auto ho_thread_id = c_thread_idx[I2]; - const auto wo_thread_id = c_thread_idx[I3]; - - constexpr auto a_e0_e1_k0_k1_e2_block_copy_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, I1, Number{}, Number{}), - max_lds_align); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_E0_E1_K0_K1_E2, - ABlockTransferThreadClusterLengths_E0_E1_K0_K1_E2, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_e0_e1_k0_k1_e2_grid_desc), - decltype(a_e0_e1_k0_k1_e2_block_copy_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2, 3, 4>, - ABlockTransferSrcVectorDim, - 4, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_E2, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - false>(a_e0_e1_k0_k1_e2_grid_desc, - make_multi_index(0, 0, k_block_work_id, 0, 0), - a_e0_e1_k0_k1_e2_block_copy_desc, - make_multi_index(0, 0, 0, 0, 0)); - - constexpr auto a_block_slice_copy_step = make_multi_index(I1, 0, 0, 0, 0); - - constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc = - make_naive_tensor_descriptor_packed(make_tuple(I1, - Number{}, - I1, - I1, - I1, - Number{}, - I1, - I1, - Number{}, - Number{})); - - auto b_threadwise_transfer = ThreadwiseTensorSliceTransfer_v2< - FloatAB, - FloatAB, - decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc), - decltype(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc), - Sequence, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - make_multi_index(0, - 0, - n_block_work_id, - ho_block_work_id, - ho_thread_id, - 0, - wo_block_work_id, - wo_thread_id, - 0, - 0)); - - auto a_block_buf = make_dynamic_buffer( - p_shared_block, a_e0_e1_k0_k1_e2_block_copy_desc.GetElementSpaceSize()); - - //// register allocation for output - // StaticBuffer - // c_thread_buf; - - // initialize output thread tensor - ThreadwiseTensorSliceSet_v1>{} - .Run(c_k1_n_h2_w2_thread_gemm_desc, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - FloatAcc{0}); - - constexpr auto b_thread_slice_copy_step = - make_multi_index(0, E1PerBlock, 0, 0, 0, 0, 0, 0, 0, 0); - - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_e0_e1_k_e2_global_step_hacks = AGlobalStepHacks{}; - constexpr auto b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks = BGlobalStepHacks{}; - - // double regsiter buffer for b - StaticBuffer - b_thread_even_buf, b_thread_odd_buf; - - if constexpr(HasMainE0BlockLoop) - { - const auto E0 = b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetLength(I0); - - index_t e0_block_data_begin = 0; - - do - { - // LDS double buffer: preload data - { - a_blockwise_copy.RunRead( - a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks); - - b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_even_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf); - } - - __syncthreads(); - - if constexpr(HasMainE1BlockLoop) - { - index_t e1_block_data_begin = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - b_threadwise_transfer.MoveSrcSliceWindow( - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_thread_slice_copy_step, - BGlobalMoveSliceWindowStepHacks{}); - - b_threadwise_transfer.Run( - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_odd_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - - blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); - - b_threadwise_transfer.MoveSrcSliceWindow( - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_thread_slice_copy_step, - BGlobalMoveSliceWindowStepHacks{}); - - b_threadwise_transfer.Run( - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_even_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); - - blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); - - e1_block_data_begin += 2 * E1PerBlock; - - } while(e1_block_data_begin < E1 - 2 * E1PerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left - { - b_threadwise_transfer.MoveSrcSliceWindow( - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_thread_slice_copy_step, - BGlobalMoveSliceWindowStepHacks{}); - - b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_odd_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - - blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); - } - else // if has 1 iteration left - { - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - } - - a_blockwise_copy.MoveSrcSliceWindow(a_e0_e1_k0_k1_e2_grid_desc, - a_block_slice_copy_step, - AGlobalMoveSliceWindowStepHacks{}); - - blockwise_gemm.MoveABlockSliceWindow(make_tuple(-(E1 - E1PerBlock), 0, 0)); - - b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_thread_slice_copy_step, - BGlobalMoveSliceWindowStepHacks{}); - - e0_block_data_begin += 1; - - } while(e0_block_data_begin < E0); - } - else - { - // LDS double buffer: preload data - { - a_blockwise_copy.RunRead( - a_e0_e1_k0_k1_e2_grid_desc, a_global_buf, a_e0_e1_k_e2_global_step_hacks); - - b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_even_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - a_blockwise_copy.RunWrite(a_e0_e1_k0_k1_e2_block_copy_desc, a_block_buf); - } - - __syncthreads(); - - if constexpr(HasMainE1BlockLoop) - { - index_t e1_block_data_begin = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - b_threadwise_transfer.MoveSrcSliceWindow( - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_thread_slice_copy_step, - BGlobalMoveSliceWindowStepHacks{}); - - b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_odd_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - - blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); - - b_threadwise_transfer.MoveSrcSliceWindow( - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_thread_slice_copy_step, - BGlobalMoveSliceWindowStepHacks{}); - - b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_even_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); - - blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); - - e1_block_data_begin += 2 * E1PerBlock; - - } while(e1_block_data_begin < E1 - 2 * E1PerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailE1BlockLoop) // if has 2 iteration left - { - b_threadwise_transfer.MoveSrcSliceWindow(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_thread_slice_copy_step, - BGlobalMoveSliceWindowStepHacks{}); - - b_threadwise_transfer.Run(b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - b_global_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_thread_copy_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), - b_thread_odd_buf, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_global_step_hacks); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - - blockwise_gemm.MoveABlockSliceWindow(make_tuple(E1PerBlock, 0, 0)); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf); - } - else // if has 1 iteration left - { - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf); - } - } - } - - template - __device__ static void - Conv(const FloatAB* __restrict__ p_a_global, - const FloatAB* __restrict__ p_b_global, - const FloatC* __restrict__ p_bias_global, - FloatC* __restrict__ p_c_global, - FloatC* __restrict__ p_d_global, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, - integral_constant) - { - const auto bias_k0_k1_grid_desc = - MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - - const auto a_global_buf = make_dynamic_buffer( - p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); - auto c_global_buf = make_dynamic_buffer( - p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); - auto d_global_buf = make_dynamic_buffer( - p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); - auto bias_global_buf = make_dynamic_buffer( - p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); - - // register allocation for output - StaticBuffer - c_thread_buf; - - const auto c_k_n_h_w_block_cluster_idx = - GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); - - const auto c_thread_mtx_index = GetCThreadIndex(); - - // GemmOp - GemmOp(a_global_buf, - b_global_buf, - c_thread_buf, - p_shared_block, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k1_n_h2_w2_thread_gemm_desc, - integral_constant{}); - - // Output - WriteOut(c_thread_buf, - c_global_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - } - - template - __device__ static void ConvBiasActiv( - const FloatAB* __restrict__ p_a_global, - const FloatAB* __restrict__ p_b_global, - const FloatC* __restrict__ p_bias_global, - FloatC* __restrict__ p_c_global, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, - integral_constant, - integral_constant) - { - static constexpr auto activ_type = integral_constant{}; - - const auto bias_k0_k1_grid_desc = - MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - - const auto a_global_buf = make_dynamic_buffer( - p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); - auto c_global_buf = make_dynamic_buffer( - p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); - auto bias_global_buf = make_dynamic_buffer( - p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); - - // register allocation for output - StaticBuffer - c_thread_buf; - - const auto c_k_n_h_w_block_cluster_idx = - GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); - - const auto c_thread_mtx_index = GetCThreadIndex(); - - // GemmOp - GemmOp(a_global_buf, - b_global_buf, - c_thread_buf, - p_shared_block, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k1_n_h2_w2_thread_gemm_desc, - integral_constant{}); - - // Bias - BiasOp(bias_global_buf, - c_thread_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - bias_k0_k1_grid_desc, - c_k1_n_h2_w2_thread_gemm_desc); - - // Activ - Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); - - // Output - WriteOut(c_thread_buf, - c_global_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - } - - template - __device__ static void ConvBiasActivMaxpool( - const FloatAB* __restrict__ p_a_global, - const FloatAB* __restrict__ p_b_global, - const FloatC* __restrict__ p_bias_global, - FloatC* __restrict__ p_c_global, - FloatC* __restrict__ p_d_global, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, - integral_constant, - integral_constant) - { - static constexpr auto activ_type = integral_constant{}; - - const auto bias_k0_k1_grid_desc = - MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - - const auto a_global_buf = make_dynamic_buffer( - p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); - auto c_global_buf = make_dynamic_buffer( - p_c_global, c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc.GetElementSpaceSize()); - auto d_global_buf = make_dynamic_buffer( - p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); - auto bias_global_buf = make_dynamic_buffer( - p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); - - // register allocation for output - StaticBuffer - c_thread_buf; - - const auto c_k_n_h_w_block_cluster_idx = - GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); - - const auto c_thread_mtx_index = GetCThreadIndex(); - - // GemmOp - GemmOp(a_global_buf, - b_global_buf, - c_thread_buf, - p_shared_block, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k1_n_h2_w2_thread_gemm_desc, - integral_constant{}); - - // Bias - BiasOp(bias_global_buf, - c_thread_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - bias_k0_k1_grid_desc, - c_k1_n_h2_w2_thread_gemm_desc); - - // Activ - Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); - - // Output - WriteOut(c_thread_buf, - c_global_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - - // MaxPool - MaxPool(c_thread_buf, - d_global_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - c_k1_n_h2_w2_thread_gemm_desc, - d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc); - } - - template - __device__ static void ConvBiasActivResizeAdd( - const FloatAB* __restrict__ p_a_global, - const FloatAB* __restrict__ p_b_global, - const FloatC* __restrict__ p_bias_global, - FloatC* __restrict__ p_d_global, - FloatAB* __restrict__ p_shared_block, - const AGridDesc_E0_E1_K0_K1_E2& a_e0_e1_k0_k1_e2_grid_desc, - const BGridDesc_E0_E1_N_H0_H1_H2_W0_W1_W2_E2& b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - const CGridDesc_K0_K1_N_H0_H1_H2_W0_W1_W2& c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc, - const DGridDesc_K0_K1_N_H0_H1_Hx_W0_W1_Wx& d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc, - const CBlockIdToBlockClusterAdaptor_K_N_H_W& cblockid_to_k_n_h_w_block_cluster_adaptor, - integral_constant, - integral_constant) - { - static constexpr auto activ_type = integral_constant{}; - - const auto bias_k0_k1_grid_desc = - MakeBiasK0K1GridDescriptor(c_k0_k1_n_h0_h1_h2_w0_w1_w2_grid_desc); - - const auto a_global_buf = make_dynamic_buffer( - p_a_global, a_e0_e1_k0_k1_e2_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_global, b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc.GetElementSpaceSize()); - auto d_global_buf = make_dynamic_buffer( - p_d_global, d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc.GetElementSpaceSize()); - auto bias_global_buf = make_dynamic_buffer( - p_bias_global, bias_k0_k1_grid_desc.GetElementSpaceSize()); - - constexpr auto c_k1_n_h2_w2_thread_gemm_desc = MakeCK1NH2W2ThreadDescriptor(); - - // register allocation for output - StaticBuffer - c_thread_buf; - - const auto c_k_n_h_w_block_cluster_idx = - GetCBlockIndex(cblockid_to_k_n_h_w_block_cluster_adaptor); - - const auto c_thread_mtx_index = GetCThreadIndex(); - - // GemmOp - GemmOp(a_global_buf, - b_global_buf, - c_thread_buf, - p_shared_block, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - a_e0_e1_k0_k1_e2_grid_desc, - b_e0_e1_n_h0_h1_h2_w0_w1_w2_e2_grid_desc, - c_k1_n_h2_w2_thread_gemm_desc, - integral_constant{}); - - // Bias - BiasOp(bias_global_buf, - c_thread_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - bias_k0_k1_grid_desc, - c_k1_n_h2_w2_thread_gemm_desc); - - // Activ - Activation(c_thread_buf, c_k1_n_h2_w2_thread_gemm_desc, activ_type); - - // Resize_Add - ResizeAdd(c_thread_buf, - d_global_buf, - c_k_n_h_w_block_cluster_idx, - c_thread_mtx_index, - c_k1_n_h2_w2_thread_gemm_desc, - d_k0_k1_n_h0_h1_hx_w0_w1_wx_grid_desc); - } -}; - -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6eca77c89cc1751e974d5cac0dd368d55ae37d21 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp @@ -0,0 +1,702 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif +#if CK_USE_WAVES_PER_EU + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) +#endif + kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__) || defined(__gfx1100__) || \ + defined(__gfx1101__) || defined(__gfx1102__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane( + GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA)); + const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane( + GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane( + GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_m_n); +#else + ignore = karg; +#endif +} + +template +struct GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + static constexpr auto max_lds_align = math::lcm(AK1, BK1); + + using ThisThreadBlock = ThisThreadBlock; + // return block_id to C matrix tile idx (m0, n0) mapping + using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) { return math::integer_divide_floor(K, AK1Value); } + __host__ static auto CalculateBK0(index_t K) { return math::integer_divide_floor(K, BK1Value); } + + // Argument + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + AK0{CalculateAK0(K)}, + BK0{CalculateBK0(K)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t AK0; + index_t BK0; + }; + + // Argument + struct Argument : public Problem, public tensor_operation::device::BaseArgument + { + __host__ Argument(const ABDataType* p_a_grid_, + const ABDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const ABDataType* p_a_grid; + const ABDataType* p_b_grid; + CDataType* p_c_grid; + }; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, AK1), max_lds_align); + } + }(); + + return a_block_desc_ak0_m_ak1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BK1), max_lds_align); + } + }(); + + return b_block_desc_bk0_n_bk1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType); + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "Wrong! AK1 must be known at the time of compilation."); + static_assert(is_known_at_compile_time>::value, + "Wrong! BK1 must be known at the time of compilation."); + + static_assert( + MPerBlock % (MPerDpp * MDppPerWave) == 0, + "Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave."); + static_assert( + NPerBlock % (NPerDpp * NDppPerWave) == 0, + "Invalid tuning parameters! NPerBlock must be divisible by NPerDpp * NDppPerWave."); + + static_assert( + KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1."); + + static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0, + "Invalid tuning parameters! AK1Value must be divisible by " + "ABlockTransferDstScalarPerVector_K1"); + + static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0, + "Invalid tuning parameters! BK1Value must be divisible by " + "BBlockTransferDstScalarPerVector_K1"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if(problem.K % KPerBlock != 0) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = problem.K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const auto num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n) + { + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), DppSelector::selected_dpp.k_per_dpp); + + using BlockwiseGemm = + BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n); + } + + static constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + __device__ static auto + MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + __device__ static auto + MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_pass_through_transform(N), + make_unmerge_transform(make_tuple(BK0, BK1Value))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + } + + __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + template + __device__ static void Run(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + const auto block_2_ctile_map = + Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)}; + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0), + c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[AK0PerBlock, MPerBlock] is in LDS + // b_mtx[BK0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), DppSelector::selected_dpp.k_per_dpp); + auto blockwise_gemm = + BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0); + // (AK0 / AK0PerBlock) is always equal to (BK0 / BK0PerBlock) + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2(); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_n2, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..89424468c682f7249f108b2515d8f11655818173 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,1034 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +namespace ck { + +// GEMM: +// input : A0[M, K], A1[M, K] +// input : B0[N, K], B1[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleABD_xdl_cshuffle +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + +#if CK_WORKAROUND_DENORM_FIX + using ComputeDataType = + conditional_t, ck::bhalf_t, ComputeDataType_>; +#else + using ComputeDataType = ComputeDataType_; +#endif + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + static constexpr auto MakeAsGridPointer() + { + return generate_tuple( + [&](auto i) { + using ADataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeBsGridPointer() + { + return generate_tuple( + [&](auto i) { + using BDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(ComputeDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + // A desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k) + { + return generate_tuple( + [&](auto i) { return MakeAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); }, + Number{}); + } + + // B desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k) + { + return generate_tuple( + [&](auto i) { return MakeBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); }, + Number{}); + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AsGridDesc_M_K& as_grid_desc_m_k, + const BsGridDesc_N_K& bs_grid_desc_n_k, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "KPerBlock must be divisible by AK1Value and BK1Value!"); + + const auto M = as_grid_desc_m_k[I0].GetLength(I0); + const auto N = bs_grid_desc_n_k[I0].GetLength(I0); + const auto AK = as_grid_desc_m_k[I0].GetLength(I1); + const auto BK = bs_grid_desc_n_k[I0].GetLength(I1); + + // check consistency of desc + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK)) + { + return false; + } + + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + bool valid = true; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + valid = + valid && (as_grid_desc_m_k[i].GetElementSpaceSize() * sizeof(ADataType) <= TwoGB); + valid = valid && (M == as_grid_desc_m_k[i].GetLength(I0) && + AK == as_grid_desc_m_k[i].GetLength(I1)); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + valid = + valid && (bs_grid_desc_n_k[i].GetElementSpaceSize() * sizeof(BDataType) <= TwoGB); + valid = valid && (N == bs_grid_desc_n_k[i].GetLength(I0) && + BK == bs_grid_desc_n_k[i].GetLength(I1)); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0)) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = AK / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // check block-to-E-tile + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + + if(!(e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using AsGridPointer = decltype(MakeAsGridPointer()); + using BsGridPointer = decltype(MakeBsGridPointer()); + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + template + __host__ __device__ static auto + MakeAsGridDescriptor_M_K(const std::array& MRaws, + const std::array& KRaws, + const std::array& AsStride) + { + return generate_tuple( + [&](auto i) { + using ALayout = remove_cvref_t>; + + return MakeAGridDescriptor_M_K(MRaws[i], KRaws[i], AsStride[i]); + }, + Number{}); + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + __host__ __device__ static auto + MakeBsGridDescriptor_N_K(const std::array& KRaws, + const std::array& NRaws, + const std::array& BsStride) + { + return generate_tuple( + [&](auto i) { + using BLayout = remove_cvref_t>; + + return MakeBGridDescriptor_N_K(KRaws[i], NRaws[i], BsStride[i]); + }, + Number{}); + } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(AsGridPointer p_as_grid, + BsGridPointer p_bs_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto as_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto bs_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + const auto idx_as_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, + Number{}); + + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + AsDataType, + Tuple, + decltype(as_grid_desc_ak0_m_ak1), + decltype(tie(a_block_desc_ak0_m_ak1)), + AElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + uniform_sequence_gen_t, + Sequence>{as_grid_desc_ak0_m_ak1, + idx_as_block_begin, + tie(a_block_desc_ak0_m_ak1), + make_tuple(make_multi_index(0, 0, 0)), + a_element_op}; + + const auto idx_bs_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, + Number{}); + + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + BsDataType, + Tuple, + decltype(bs_grid_desc_bk0_n_bk1), + decltype(tie(b_block_desc_bk0_n_bk1)), + BElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + uniform_sequence_gen_t, + Sequence>{bs_grid_desc_bk0_n_bk1, + idx_bs_block_begin, + tie(b_block_desc_bk0_n_bk1), + make_tuple(make_multi_index(0, 0, 0)), + b_element_op}; + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeDataType, // ComputeDataType for A + ComputeDataType, // ComputeDataType for B + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) / + KPerBlock); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + gridwise_gemm_pipeline.template Run(as_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + as_grid_buf, + a_block_buf, + a_block_slice_copy_step, + bs_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + bs_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run(AsGridPointer p_as_grid, + BsGridPointer p_bs_grid, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const std::array StrideAs, + const std::array StrideBs, + const std::array StrideDs, + const index_t StrideE, + const Block2ETileMap& block_2_etile_map) + { + using AsGridDesc_M_K = + remove_cvref_t({}, {}, {}))>; + using BsGridDesc_N_K = + remove_cvref_t({}, {}, {}))>; + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + AsGridDesc_M_K as_grid_desc_m_k; + BsGridDesc_N_K bs_grid_desc_n_k; + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumATensor, 1>{}([&](auto j) { + using ALayout = remove_cvref_t>; + + as_grid_desc_m_k(j) = MakeAGridDescriptor_M_K(M, K, StrideAs[j]); + }); + + static_for<0, NumBTensor, 1>{}([&](auto j) { + using BLayout = remove_cvref_t>; + + bs_grid_desc_n_k(j) = MakeBGridDescriptor_N_K(N, K, StrideBs[j]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k); + + const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 578665ea85fb723565e597bbec26457ffb449606..5c9f40b51a5ba28810352cf9dcf6d6e2dca1a404 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -92,8 +92,8 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -300,8 +300,9 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 remove_cvref_t; using DefaultBGridDesc_BK0_N_BK1 = remove_cvref_t; - using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; // Support 2 dimension in the future. Not only M using RGridDescriptor_MBlock_MPerBlock = @@ -469,6 +470,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index d3f81566e93ab6b72f4edc7acfe097ef50e72270..53b2169bc652dd27a1910dc5515a86c9ad7e206d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -36,7 +36,7 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle( + kernel_grouped_conv_multiple_d_wmma_cshuffle( const ADataType* __restrict__ p_a_grid, const BDataType* __restrict__ p_b_grid, DsPointer p_ds_grid, @@ -346,8 +346,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { @@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + // CheckValidity for kernels without multi D template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const DsGridDesc_M_N& ds_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n, const Block2CTileMap& block_2_ctile_map) { @@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle const auto N = b_grid_desc_k0_n_k1.GetLength(I1); const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - bool valid = true; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && - N == ds_grid_desc_m_n[i].GetLength(I1)); - }); - - if(!valid) - { - return false; - } - if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && K1 == b_grid_desc_k0_n_k1.GetLength(I2))) @@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle return true; } + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + bool valid = true; + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + return CheckValidity( + a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n, block_2_ctile_map); + } + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / (K0PerBlock * K1); @@ -565,10 +578,12 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle e_grid_desc_m_n); } - using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; using DsGridPointer = decltype(MakeDsGridPointer()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 98a71a7c247f17f78c924b013fe0b5764ef257c1..15c30a0dadc51334a6463f80d39ba8c440ad363f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -15,6 +15,9 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + namespace ck { // GEMM: @@ -26,7 +29,9 @@ namespace ck { // E = cde_op(C, D0, D1, ...) // Assume: // D0, D1, ... and E have the same layout -template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename BComputeDataType = AComputeDataType_> struct GridwiseGemmMultipleD_xdl_cshuffle { static constexpr index_t NumDTensor = DsDataType::Size(); + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -89,18 +97,14 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; - // denorm test fix, required to work around fp16 mfma issue - // we convert fp16->fp32->bf16 and execute bf16 mfma instruction - // when mfma if fixed, remove this section and update - // ABDataTypeAdjusted -> ABDataType throughout this file -#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) - using ABDataTypeAdjusted = - conditional_t, ck::bhalf_t, ABDataType>; +#if CK_WORKAROUND_DENORM_FIX + using AComputeDataType = + conditional_t, ck::bhalf_t, AComputeDataType_>; #else - using ABDataTypeAdjusted = ABDataType; + using AComputeDataType = AComputeDataType_; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -169,8 +173,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ABDataType), + return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) + + b_block_space_size_aligned * sizeof(BComputeDataType), c_block_size * sizeof(CShuffleDataType)); } @@ -265,13 +269,16 @@ struct GridwiseGemmMultipleD_xdl_cshuffle static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "KPerBlock must be divisible by AK1Value and BK1Value!"); - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto AK = a_grid_desc_m_k.GetLength(I1); + const auto BK = b_grid_desc_n_k.GetLength(I1); // check consistency of desc - if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK)) { return false; } @@ -289,13 +296,13 @@ struct GridwiseGemmMultipleD_xdl_cshuffle } // check tile size - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0)) { return false; } // check gridwise gemm pipeline - const auto num_k_loop = K / KPerBlock; + const auto num_k_loop = AK / KPerBlock; if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { @@ -312,8 +319,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // check tensor size: cannot be larger than 2GB each constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && - b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) { return false; @@ -331,14 +338,102 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using DsGridPointer = decltype(MakeDsGridPointer()); + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + template - __device__ static void Run(const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer p_ds_grid, EDataType* __restrict__ p_e_grid, void* __restrict__ p_shared, @@ -407,8 +502,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle Sequence, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, - ABDataType, - ABDataTypeAdjusted, + ADataType, + AComputeDataType, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -438,8 +533,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle Sequence, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, - ABDataType, - ABDataTypeAdjusted, + BDataType, + BComputeDataType, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -467,13 +562,15 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = - math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ABDataTypeAdjusted, + AComputeDataType, + BComputeDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -491,11 +588,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), - a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); @@ -760,6 +856,85 @@ struct GridwiseGemmMultipleD_xdl_cshuffle }); } } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + // tensor descriptors for problem definiton + const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K(M, K, StrideA); + const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K(K, N, StrideB); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + + const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4f37049b203b9f716e50fef2d13b72b2e345aba4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -0,0 +1,1077 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleD_xdl_splitk_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, AK0PerBlock, Number{}, AK1), + make_tuple(AK0PerBlock * Number{} * AK1, + Number{} * AK1, + AK1, + I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, BK0PerBlock, Number{}, BK1), + make_tuple(BK0PerBlock * Number{} * BK1, + Number{} * BK1, + BK1, + I1)); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(ComputeType), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch) + { + return math::integer_least_multiple(K, KPerBlock * K_Batch); + } + + template + __host__ __device__ static auto + MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto MPad = CalculateMPadded(M); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto AK0 = KPad / (KBatch * AK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto NPad = CalculateNPadded(N); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto BK0 = KPad / (KBatch * BK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + template + __host__ __device__ static constexpr bool + CheckValidity(const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch) + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + ignore = StrideDs; + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + +#if 0 + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } +#endif + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const index_t KBatch, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation_& cde_element_op, + const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1, + const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_kbatch_ak0_m_ak1 = + GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_kbatch_bk0_n_bk1 = + GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ComputeType, + decltype(a_grid_desc_kbatch_ak0_m_ak1), + decltype(a_block_desc_kbatch_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_kbatch_ak0_m_ak1, + make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_kbatch_ak0_m_ak1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + ComputeType, + decltype(b_grid_desc_kbatch_bk0_n_bk1), + decltype(b_block_desc_kbatch_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_kbatch_bk0_n_bk1, + make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_kbatch_bk0_n_bk1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeType, + ComputeType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + +#if 1 + if(block_work_idx[I0] == 0) + { + const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; + const index_t numNThreads = NPerBlock / nThreadSize; + const index_t numMThreads = BlockSize / numNThreads; + const index_t mThreadSize = MPerBlock / numMThreads; + + const index_t m_tid = get_thread_local_1d_id() / numNThreads; + const index_t n_tid = get_thread_local_1d_id() % numNThreads; + + auto c_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + StaticBuffer + e_thread_zero_buf; + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + EDataType, + EDataType, + decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, mThreadSize, 1, nThreadSize>, + Sequence<0, 1, 2, 3>, + 3, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], + m_tid * mThreadSize, + block_work_idx[I2], + n_tid * nThreadSize), + ck::tensor_operation::element_wise::PassThrough{}}; + + c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_zero_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + __syncthreads(); + + if(threadIdx.x == 0) + { + atomicAdd(barrier_count_finished, 1); + } + } +#endif + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = + __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_kbatch_ak0_m_ak1, + a_block_desc_kbatch_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_kbatch_bk0_n_bk1, + b_block_desc_kbatch_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + if(threadIdx.x == 0) + { + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + } + + __syncthreads(); + + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); + }, + Number{})); + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation_, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), + cde_element_op}; + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor_, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + + if(threadIdx.x == 0) + { + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); + + if(k_id_finished_t == KBatch) + { + *barrier_count_finished = 0; + } + } + } + } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + if(kbatch_id == KBatch - 1) + { + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + else + { + Run>( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + ck::tensor_operation::element_wise::PassThrough{}, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 98331d854490abe4e7f478d817104218cb49c28f..f760feb2ed22015567ef66648d0d1afa81bdff9d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -1,8 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include + #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index d1209636de93a13cbe7f851575b4abaedd99af72..754a3e89c9327792fb370e30fbacdee2b9758163 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -4,7 +4,8 @@ #pragma once #include "ck/utility/common_header.hpp" -#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp index 3281b910d3d8782fc18f91ee77ce42aa44fc5ba9..25e1cebdbe06918cfea98132209116a2d74de822 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -9,13 +9,13 @@ namespace ck { struct GridwiseGemmPipeline_v2 { - __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + __host__ __device__ static constexpr bool IsSupported(const index_t num_loop) { // TODO: improve applicability return num_loop % 2 == 0; } - __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + __host__ __device__ static constexpr bool CalculateHasMainLoop(const index_t num_loop) { return (num_loop / 2) > 1; } @@ -79,6 +79,10 @@ struct GridwiseGemmPipeline_v2 do { +#if CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT + __builtin_amdgcn_iglp_opt(CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT); +#endif + block_sync_lds(); // GEMM i diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ced62241cd29b54ba8b6a69c87896268b9b49a7e --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { + +struct GridwiseGemmPipeline_v3 +{ + __host__ __device__ static constexpr bool IsSupported(index_t) + { + // TODO: improve applicability + return true; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // global read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // LDS write 0 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + num_loop--; + + while(num_loop > 0) + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + block_sync_lds(); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + num_loop--; + } + // tail + { + block_sync_lds(); + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index a3f5324713c2d61e94434370ee9040ab12475279..d75b631e619ac4aac80bcc59d8f6b8f4d42c2612 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -55,7 +55,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940)) + defined(__gfx940) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -164,8 +164,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -318,8 +318,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using ReduceGridDescriptor_MBlock_MPerBlock = remove_cvref_t; @@ -456,6 +457,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index aa89bff9ee21873fdc57c839eb800be455005a9a..e7dc0d3eb0d9296f4e9ca98e68d0cadfb3cbcfb1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -375,10 +375,12 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle remove_cvref_t; using DefaultBGridDesc_BK0_N_BK1 = remove_cvref_t; - using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2ETileMap = remove_cvref_t; @@ -586,6 +588,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ABDataType, + ABDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -1010,6 +1013,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ABDataType, + ABDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..caf8f040f4a78232386a1f0c3d0be02b4dc63295 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp @@ -0,0 +1,1076 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleD_xdl_splitk_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, AK0PerBlock, Number{}, AK1), + make_tuple(AK0PerBlock * Number{} * AK1, + Number{} * AK1, + AK1, + I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, BK0PerBlock, Number{}, BK1), + make_tuple(BK0PerBlock * Number{} * BK1, + Number{} * BK1, + BK1, + I1)); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch) + { + return math::integer_least_multiple(K, KPerBlock * K_Batch); + } + + template + __host__ __device__ static auto + MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto MPad = CalculateMPadded(M); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto AK0 = KPad / (KBatch * AK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto NPad = CalculateNPadded(N); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto BK0 = KPad / (KBatch * BK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + template + __host__ __device__ static constexpr bool + CheckValidity(const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch) + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + ignore = StrideDs; + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + +#if 0 + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } +#endif + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const index_t KBatch, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation_& cde_element_op, + const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1, + const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_kbatch_ak0_m_ak1 = + GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_kbatch_bk0_n_bk1 = + GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ComputeType, + decltype(a_grid_desc_kbatch_ak0_m_ak1), + decltype(a_block_desc_kbatch_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_kbatch_ak0_m_ak1, + make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_kbatch_ak0_m_ak1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + ComputeType, + decltype(b_grid_desc_kbatch_bk0_n_bk1), + decltype(b_block_desc_kbatch_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_kbatch_bk0_n_bk1, + make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_kbatch_bk0_n_bk1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + +#if 1 + if(block_work_idx[I0] == 0) + { + const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; + const index_t numNThreads = NPerBlock / nThreadSize; + const index_t numMThreads = BlockSize / numNThreads; + const index_t mThreadSize = MPerBlock / numMThreads; + + const index_t m_tid = get_thread_local_1d_id() / numNThreads; + const index_t n_tid = get_thread_local_1d_id() % numNThreads; + + auto c_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + StaticBuffer + e_thread_zero_buf; + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + EDataType, + EDataType, + decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, mThreadSize, 1, nThreadSize>, + Sequence<0, 1, 2, 3>, + 3, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], + m_tid * mThreadSize, + block_work_idx[I2], + n_tid * nThreadSize), + ck::tensor_operation::element_wise::PassThrough{}}; + + c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_zero_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + __syncthreads(); + + if(threadIdx.x == 0) + { + atomicAdd(barrier_count_finished, 1); + } + } +#endif + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = + __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_kbatch_ak0_m_ak1, + a_block_desc_kbatch_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_kbatch_bk0_n_bk1, + b_block_desc_kbatch_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + if(threadIdx.x == 0) + { + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + } + + __syncthreads(); + + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); + }, + Number{})); + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation_, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), + cde_element_op}; + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor_, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + + if(threadIdx.x == 0) + { + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); + + if(k_id_finished_t == KBatch) + { + *barrier_count_finished = 0; + } + } + } + } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + if(kbatch_id == KBatch - 1) + { + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + else + { + Run>( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + ck::tensor_operation::element_wise::PassThrough{}, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp index 2d3a36fca08d7811e406a89b398f401b741a7649..de5a4241986ff64e1797f2304da117178059ea4b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 397ae1c1b9b7b9763b7eb39855ddf4c752db949d..d8b31311b1040f6c77e955550875a2a3e1337530 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -138,8 +138,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { @@ -308,8 +308,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2CTileMap = remove_cvref_t; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 1213cdc263909c6e15c134d5d7962aeb02f84a91..d700314a26062471c24974d39e0a576840a40b75 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,73 +17,64 @@ namespace ck { +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + template __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, + kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map) + typename GridwiseGemm::Problem problem) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - a_element_op, - b_element_op, - c_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_ctile_map); + GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, p_shared, problem); #else ignore = p_a_grid; ignore = p_b_grid; ignore = p_c_grid; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = a_grid_desc_ak0_m_ak1; - ignore = b_grid_desc_bk0_n_bk1; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = block_2_ctile_map; + ignore = problem; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } -template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeTypeA = FloatC, + typename ComputeTypeB = ComputeTypeA> struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 { static constexpr auto I0 = Number<0>{}; @@ -129,35 +122,396 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 static constexpr auto I7 = Number<7>{}; // K1 should be Number<...> - static constexpr auto AK0 = Number{}; - static constexpr auto BK0 = Number{}; - static constexpr auto AK1 = Number{}; - static constexpr auto BK1 = Number{}; + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; using ThisThreadBlock = ThisThreadBlock; + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + return CalculateKPadded(K) / AK1Value; + } + else + { + return K / AK1Value; + } + } + + __host__ static auto CalculateBK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + return CalculateKPadded(K) / BK1Value; + } + else + { + return K / BK1Value; + } + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_floor(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_floor(N, NPerBlock); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KPadded{CalculateKPadded(K_)}, + AK0{CalculateAK0(K_)}, + BK0{CalculateBK0(K_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const FloatA* p_a_grid; + const FloatB* p_b_grid; + FloatC* p_c_grid; + }; + // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { // A matrix in LDS memory, dst of blockwise copy return make_naive_tensor_descriptor( - make_tuple(AK0, Number{}, AK1), - make_tuple(Number{} * AK1, AK1, I1)); + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); } - __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() { // B matrix in LDS memory, dst of blockwise copy return make_naive_tensor_descriptor( - make_tuple(BK0, Number{}, BK1), - make_tuple(Number{} * BK1, BK1, I1)); + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); } - __host__ __device__ static constexpr auto - GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() { constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); @@ -172,14 +526,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { // LDS allocation for A and B: be careful of alignment constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1, BK1); + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); constexpr auto a_block_space_size_aligned = math::integer_least_multiple( a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); @@ -194,42 +548,108 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(FloatAB), + return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) + + b_block_space_size_aligned * sizeof(ComputeTypeB)), c_block_size * sizeof(FloatCShuffle)); } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - template - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, - const CGridDesc_M_N& c_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + __host__ static constexpr bool CheckValidity(const Problem& problem) { static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); - const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); - const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); - const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } - if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) - return false; + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) - return false; + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding) + { + if(!(CalculateKPadded(problem.K) % AK1Value == 0) || + !(CalculateKPadded(problem.K) % BK1Value == 0)) + { + return false; + } + } + else + { + if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0)) + { + return false; + } + } - // check gridwise gemm pipeline - const auto num_k_loop = K / KPerBlock; + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } - if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + if constexpr(is_same::value) { - return false; + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } } - if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + // check gridwise gemm pipeline + const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -238,22 +658,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 return true; } - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / KPerBlock; return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } - __host__ __device__ static constexpr auto - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const auto MBlock = M / MPerBlock; - const auto NBlock = N / NPerBlock; - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( c_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), @@ -265,33 +680,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 } // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) - { - return BlockToCTileMap_M00_N0_M01Adapt( - c_grid_desc_m_n); - } - - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; - - using DefaultBlock2CTileMap = - remove_cvref_t; + using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; - template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, + template + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, void* __restrict__ p_shared, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap& block_2_ctile_map) + const Problem& problem) { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -299,7 +707,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N}; + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -319,7 +733,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1, BK1); + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -333,11 +747,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough, InMemoryDataOperationEnum::Set, - Sequence, + Sequence, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, + FloatA, + ComputeTypeA, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -364,11 +778,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough, InMemoryDataOperationEnum::Set, - Sequence, + Sequence, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, + FloatB, + ComputeTypeB, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -397,11 +811,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // register // sanity check constexpr index_t KPack = math::max( - math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - FloatAB, + ComputeTypeA, + ComputeTypeB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -419,14 +835,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); // gridwise GEMM pipeline static_assert(std::is_default_constructible_v); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 2d4ebe7076408b29cffb863630e903a195ceb6de..013120c5406c04367f6a0fd654a6bd300382f3cf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -58,7 +58,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; // TODO ANT: separate into MMA + Epilogue @@ -173,8 +173,8 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -345,8 +345,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_grid_desc_m_n); } - using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using C0GridDescriptor_NBlock_NPerBlock = remove_cvref_t; @@ -494,6 +495,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatAB, + FloatAB, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp index acece0fbba42bc322cb3102aaf8898a77b41b9d0..8675a9242acf0d573cc0316f3f89b95b654b38af 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -330,8 +330,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle return e_grid_desc_mblock_mperblock_nblock_nperblock; } - using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; using DefaultBlock2ETileMap = remove_cvref_t; @@ -493,6 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< TileMathThreadGroupSize, ABDataType, + ABDataType, FloatGemmAcc, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index 1979331d074cec548eac549b2690e853eeffd44b..06c87d1892569be92d64a01a689c17a4570baad6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -139,7 +139,8 @@ __host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLen } template (p_a_grid, @@ -181,21 +182,22 @@ __global__ void c_element_op, c_block_cluster_adaptor); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_b_k0_m_k1_grid_desc; - ignore = b_b_k0_n_k1_grid_desc; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = c_block_cluster_adaptor; + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeTypeA = FloatA, + typename ComputeTypeB = ComputeTypeA> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight { static constexpr auto I0 = Number<0>{}; @@ -259,17 +263,22 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; // denorm test fix, required to work around fp16 mfma issue // we convert fp16->fp32->bf16 and execute bf16 mfma instruction // when mfma if fixed, remove this section and update - // FloatABAdjusted -> FloatAB throughout this file -#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) - using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; + // FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB, + // throughout this file +#if CK_WORKAROUND_DENORM_FIX + using FloatAAdjusted = + conditional_t, ck::bhalf_t, ComputeTypeA>; + using FloatBAdjusted = + conditional_t, ck::bhalf_t, ComputeTypeB>; #else - using FloatABAdjusted = FloatAB; + using FloatAAdjusted = ComputeTypeA; + using FloatBAdjusted = ComputeTypeB; #endif // M0/M1/M1Padding @@ -506,7 +515,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr auto c_block_size = GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); - return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), + return math::max((a_block_space_size * sizeof(FloatAAdjusted) + + b_block_space_size * sizeof(FloatBAdjusted)), c_block_size * sizeof(FloatC)); } @@ -610,8 +620,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, void* __restrict__ p_shared, const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, @@ -673,8 +683,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight Sequence<1, K0PerBlock, MPerBlock, K1>, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatABAdjusted, + FloatA, + FloatAAdjusted, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), ABlockTransferSrcAccessOrder, @@ -703,8 +713,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight Sequence<1, K0PerBlock, NPerBlock, K1>, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatABAdjusted, + FloatB, + FloatBAdjusted, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), BBlockTransferSrcAccessOrder, @@ -733,11 +743,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // sanity check constexpr index_t KPack = - math::max(K1, MfmaSelector::selected_mfma.k_per_blk); + math::max(K1, + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( - static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); + static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size, + static_cast(p_shared) + a_block_space_size, b_k0_n_k1_block_desc.GetElementSpaceSize()); // gridwise GEMM pipeline diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp index 8d86f3c1d75190bd2c285044c234d17f41e7f7d2..b12bcee0f414c6e56b6550c084630b86ab153bb5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -45,7 +45,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2b1814c03bd609c83a8fe2bbd9f50bbaa61d286c --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp @@ -0,0 +1,1185 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp" +#include "ck/utility/workgroup_barrier.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid, + const typename GridwiseGemm::FloatAB* p_b_grid, + typename GridwiseGemm::FloatC* p_c_grid, + void* p_workspace, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + typename GridwiseGemm::Block2CTileMap block_mapping) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + + __shared__ uint8_t p_shared[shared_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_workspace, + M, + N, + K, + StrideA, + StrideB, + StrideC, + block_mapping, + static_cast(p_shared)); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_workspace; + ignore = M; + ignore = N; + ignore = K; + ignore = StrideA; + ignore = StrideB; + ignore = StrideC; + ignore = block_mapping; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + static constexpr auto M01 = 1; + static constexpr auto N01 = 1; + static constexpr auto KPerBlock = K0PerBlock * K1; + + using ThisThreadBlock = ThisThreadBlock; + using FloatAcc = FloatAcc_; + using FloatCShuffle = FloatAcc; + + using Block2CTileMap = Block2CTileMap_; + using FloatAB = FloatAB_; + using FloatC = FloatC_; + + struct Argument : public ck::tensor_operation::device::BaseArgument + { + const FloatAB* p_a_grid; + const FloatAB* p_b_grid; + FloatC* p_c_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + Block2CTileMap block_mapping; + + Argument(const FloatAB* p_a_grid_, + const FloatAB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + uint32_t num_cu, + uint32_t occupancy, + uint32_t num_sk_blocks_) + : p_a_grid(p_a_grid_), + p_b_grid(p_b_grid_), + p_c_grid(p_c_grid_), + M(M_), + N(N_), + K(K_), + StrideA(StrideA_), + StrideB(StrideB_), + StrideC(StrideC_), + block_mapping(M, N, K, num_cu, occupancy, num_sk_blocks_) + { + } + + void Print() const + { + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << std::endl; + } + }; + + __host__ __device__ static auto CalculateGridSize(const Argument& karg) + { + return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock), + math::integer_divide_ceil(karg.M, MPerBlock), + karg.k_batch); + } + + __host__ __device__ static auto CalculateK0(index_t KPad) { return KPad / K1; } + + __host__ __device__ static auto + MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA) + { + const index_t K0 = CalculateK0(KPad); + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor(a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + __host__ __device__ static auto + MakeBGridDescriptor_K0_N_K1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB) + { + const index_t K0 = CalculateK0(KPad); + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor(b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return transform_tensor_descriptor(c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle().GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + __host__ __device__ static constexpr bool CheckValidity(const Argument& karg) + { + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + return false; + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + return false; + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + return false; + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + return false; + } + + if constexpr(is_same::value) + { + if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + return false; + } + else + { + if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = K0 > K0PerBlock; + + return has_main_k0_block_loop; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})); + } + + __host__ __device__ static constexpr auto GetClusterLengthReduction() + { + // TODO: assume C is row major + // TODO: we always first loop over N, then M + constexpr auto NPerBlockPow2 = math::next_power_of_two(); + constexpr auto NPerBlockReduction = + NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL; + constexpr auto MPerBlockReduction = + (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction; + return Sequence{}; + } + + __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor() + { + const auto c_partial_acc_block_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock), + make_tuple(NPerBlock, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock), + make_tuple(I1, MPerBlock)); + } + }(); + return c_partial_acc_block_m_n; + } + + using CGridDesc_M_N = remove_cvref_t; + + __device__ static void Run(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + void* p_workspace, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + Block2CTileMap block_mapping, + void* __restrict__ p_shared_block) + { + uint32_t m = M; + uint32_t n = N; + uint32_t k = K; + uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock; + uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock; + uint32_t pad_k = (k + KPerBlock - 1) / KPerBlock * KPerBlock; + uint32_t stride_a = StrideA; + uint32_t stride_b = StrideB; + uint32_t stride_c = StrideC; + + const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a); + const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const AElementwiseOperation a_element_op = AElementwiseOperation{}; + const BElementwiseOperation b_element_op = BElementwiseOperation{}; + const CElementwiseOperation c_element_op = CElementwiseOperation{}; + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = static_cast(p_shared_block); + FloatAB* p_b_block = static_cast(p_shared_block) + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize()); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3(); + + uint32_t block_idx = block_mapping.get_block_idx(); + bool is_sk_block = block_idx < block_mapping.sk_num_blocks; + bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx && + block_idx < block_mapping.reduction_start_block_idx; + bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx; + bool is_padding_block = block_idx >= block_mapping.sk_num_blocks && + block_idx < block_mapping.dp_start_block_idx; + uint32_t iter_start, iter_end; + block_mapping.get_block_itr(block_idx, iter_start, iter_end); + uint32_t total_iter_length = iter_end - iter_start; + + if(is_padding_block) + return; + + uint32_t* p_semaphore = + reinterpret_cast(reinterpret_cast(p_workspace) + + block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc))); + + if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction) + { + if(is_reduction_block) + { + // descriptors + constexpr auto cluster_length_reduce = GetClusterLengthReduction(); + constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce); + const auto reduce_thread_cluster_idx = + reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0]; + const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1]; + + constexpr auto MReduceIters = + math::integer_divide_ceil(Number{}, cluster_length_reduce.At(I0)); + constexpr auto NReduceIters = math::integer_divide_ceil( + Number{}, + cluster_length_reduce.At(I1) * + Number{}); + + constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{})); + constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed( + make_tuple(I1, I1, I1, Number{})); + + constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor(); + + constexpr auto partial_acc_load_step_n = make_multi_index( + 0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL); + constexpr auto partial_acc_load_step_n_reverse = + make_multi_index(0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CBlockTransferScalarPerVector_NWaveNPerXDL); + constexpr auto partial_acc_load_step_m = + make_multi_index(cluster_length_reduce.At(I0), 0); + + constexpr auto partial_acc_store_step_n = make_multi_index( + 0, + 0, + 0, + cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL); + constexpr auto partial_acc_store_step_n_reverse = + make_multi_index(0, + 0, + 0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CBlockTransferScalarPerVector_NWaveNPerXDL); + constexpr auto partial_acc_store_step_m = + make_multi_index(0, cluster_length_reduce.At(I0), 0, 0); + + StaticBuffer + parcial_acc_buf; + StaticBuffer + acc_buf; + + // start to compute + auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx; + auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n); + + workgroup_barrier wg_barrier(p_semaphore); + + uint32_t tile_acc_offset_start = + block_mapping.get_acc_buffer_offset_from_tile(reduction_idx); + uint32_t tile_acc_offset_end = + block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1); + + auto acc_load = ThreadwiseTensorSliceTransfer_v2< + FloatAcc, // SrcData, + FloatAcc, // DstData, + decltype(c_partial_acc_block_m_n), // SrcDesc, + decltype(acc_thread_buf_load_desc), // DstDesc, + Sequence<1, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths, + Sequence<0, 1>, // DimAccessOrder, + 1, // SrcVectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector, + 1, // SrcScalarStrideInVector, + false // SrcResetCoordinateAfterRun, + >{c_partial_acc_block_m_n, + make_multi_index(thread_m_cluster_id, + thread_n_cluster_id * + CBlockTransferScalarPerVector_NWaveNPerXDL)}; + + auto acc_store = ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, // SrcData, + FloatC, // DstData, + decltype(acc_thread_buf_store_desc), // SrcDesc, + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc, + CElementwiseOperation, // ElementwiseOperation, + Sequence<1, 1, 1, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths, + Sequence<0, 1, 2, 3>, // DimAccessOrder, + 3, // DstVectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector, + InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp, + 1, // DstScalarStrideInVector, + false // DstResetCoordinateAfterRun, + >{c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]), + thread_m_cluster_id, + __builtin_amdgcn_readfirstlane(spatial_idx[I1]), + thread_n_cluster_id * + CBlockTransferScalarPerVector_NWaveNPerXDL), + CElementwiseOperation{}}; + + // block synchronization + wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start); + +#if 0 + if(threadIdx.x == 0) { + printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast(blockIdx.x), + reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end), + __builtin_amdgcn_readfirstlane(spatial_idx[I0]), + __builtin_amdgcn_readfirstlane(spatial_idx[I1])); + } +#endif + + using Accumulation = ck::detail:: + AccumulateWithNanCheck; + + for(int i_m = 0; i_m < MReduceIters; i_m++) + { + static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) { + acc_buf.Clear(); + for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++) + { + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + + i * c_partial_acc_block_m_n.GetElementSpaceSize(), + c_partial_acc_block_m_n.GetElementSpaceSize()); + + acc_load.Run(c_partial_acc_block_m_n, + c_partial_acc_buf, + acc_thread_buf_load_desc, + make_tuple(I0, I0), + parcial_acc_buf); + + static_for<0, CBlockTransferScalarPerVector_NWaveNPerXDL, 1>{}( + [&](auto i_vec) { + constexpr auto offset = + acc_thread_buf_load_desc.CalculateOffset( + make_tuple(0, i_vec)); + Accumulation::Calculate(acc_buf(Number{}), + parcial_acc_buf[Number{}]); + }); + } + + if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL < + NPerBlock) + { + acc_store.Run(acc_thread_buf_store_desc, + make_tuple(I0, I0, I0, I0), + acc_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + if constexpr(NReduceIters != 1) + { + if constexpr(i_n_reduce != (NReduceIters - 1)) + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n); + } + else + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n_reverse); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n_reverse); + } + } + }); + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_m); + acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_m); + } + } + return; + } + } + + // offset for last acc buffer of this block + uint32_t block_acc_offset = + (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock * + NPerBlock; + + while(true) + { + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length)); + uint32_t tile_idx, iter_offset; + block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset); + iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); + auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock); + + const index_t k0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k0_m_k1_grid_desc), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_k0_m_k1_grid_desc, + make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k0_n_k1_grid_desc), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_k0_n_k1_grid_desc, + make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + const index_t num_k_block_main_loop = current_iter_length; + + gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_k0_n_k1_grid_desc, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle = + GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle(); + + constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle = + GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle(); + + auto c_block_buf = make_dynamic_buffer( + reinterpret_cast(p_shared_block), + c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize()); + + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + block_acc_offset, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_tuple(make_freeze_transform(I0), // freeze mblock + make_unmerge_transform( + make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform( + make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatCShuffle, + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + // InMemoryDataOperationEnum::Set, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]), + 0, + __builtin_amdgcn_readfirstlane(spatial_idx[I1]), + 0), + c_element_op}; + + // LDS to global partial acc + auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + // InMemoryDataOperationEnum::Set, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatCShuffle, // typename DstData, + decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle), + decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be false, + // othre wise has scratch + false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be false, + // othre wise has scratch + {c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_multi_index(0, 0, 0, 0), + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_multi_index(0, 0, 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + c_block_copy_lds_to_global.SetSrcSliceOrigin( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_tuple(0, 0, 0, 0)); + + // LDS to global + if(is_dp_block) + c_block_copy_lds_to_global.template Run( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + else if(is_sk_block) + { + if constexpr(Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + // constexpr offset + c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_tuple(0, 0, 0, 0)); + + c_block_copy_lds_to_partial_acc.SetDstSliceOrigin( + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0)); + + c_block_copy_lds_to_partial_acc + .template Run( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + c_block_buf, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + c_partial_acc_buf); + } + else if constexpr(Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Atomic) + { + c_block_copy_lds_to_global + .template Run( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + } + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + mxdlperwave_forward_step); + } + }); + + if constexpr(Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + if(is_sk_block) + { + // increase the counter for this tile + workgroup_barrier wg_barrier(p_semaphore); + wg_barrier.inc(tile_idx); + } + } + } + + // exit condition + iter_end -= current_iter_length; + if(iter_end <= iter_start) + break; + + if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction) + { + block_acc_offset -= MPerBlock * NPerBlock; + } + // make sure next loop LDS is ready for use + block_sync_lds(); + } + } + + template + struct LStr + { + static std::string Get() { return ""; } + }; + + template <> + struct LStr + { + static std::string Get() { return "R"; } + }; + + template <> + struct LStr + { + static std::string Get() { return "C"; } + }; + + static std::string GetTypeString() + { + auto str = std::stringstream(); + + // clang-format off + str << "GemmXdlStreamK_" + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << "_" + << "B" << BlockSize << "_" + << "Vec" << ABlockTransferSrcScalarPerVector << "x" + << BBlockTransferSrcScalarPerVector << "x" + << CBlockTransferScalarPerVector_NWaveNPerXDL << "_" + << MPerBlock << "x" + << NPerBlock << "x" + << K0PerBlock << "x" + << K1 ; + // clang-format on + + return str.str(); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 775b77118c2f54a8c0fc67bf5b2d538417c019c0..2c42ac934561f51cabbdc192dc321b2b975a67fa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -7,6 +7,7 @@ #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" @@ -21,30 +22,24 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v2r3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) +#if CK_USE_WAVES_PER_EU + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) +#endif + kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M_N c_grid_desc_m_n) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -53,22 +48,49 @@ __global__ void p_shared, a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - a_element_op, - b_element_op, - c_element_op, - block_2_ctile_map); + c_grid_desc_m_n); #else ignore = p_a_grid; ignore = p_b_grid; ignore = p_c_grid; ignore = a_grid_desc_k0_m_k1; ignore = b_grid_desc_k0_n_k1; - ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = block_2_ctile_map; + ignore = c_grid_desc_m_n; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif +#if CK_USE_WAVES_PER_EU + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) +#endif + kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const auto a_grid_desc_k0_m_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); + const auto b_grid_desc_k0_n_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( + karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m_n); +#else + ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -77,9 +99,6 @@ template ; - using GridwiseGemmPipe = remove_cvref_t())>; + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + template + __host__ static auto CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(c_grid_desc_m_n), 1, 1); + } + + template + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateK0(index_t K) { return math::integer_divide_ceil(K, K1Value); } + + // Argument + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + K0{CalculateK0(K_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "K0:" << K0 << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t K0; + }; + + // Argument + struct Argument : public Problem, public tensor_operation::device::BaseArgument + { + __host__ Argument(const FloatAB* p_a_grid_, + const FloatAB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const FloatAB* p_a_grid; + const FloatAB* p_b_grid; + FloatC* p_c_grid; + }; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; // denorm test fix, required to work around fp16 mfma issue // we convert fp16->fp32->bf16 and execute bf16 mfma instruction // when mfma if fixed, remove this section and update // FloatABAdjusted -> FloatAB throughout this file -#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) +#if CK_WORKAROUND_DENORM_FIX using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; #else using FloatABAdjusted = FloatAB; @@ -204,13 +322,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); } - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - template + template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M_N& c_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + const CGridDesc_M_N& c_grid_desc_m_n) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -239,7 +355,22 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return false; } - if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + "Invalid tuning param!"); + + // check gridwise gemm pipeline + const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock); + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; } @@ -248,15 +379,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return true; } - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const index_t num_loop = K / (K0PerBlock * K1); + const index_t num_loop = math::integer_divide_ceil(K, K0PerBlock * K1); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } + template __host__ __device__ static constexpr auto - MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc& c_grid_desc_m_n) { constexpr auto max_lds_align = K1; @@ -292,6 +424,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + + template + __device__ static void Run(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n) { - return BlockToCTileMap_M00_N0_M01Adapt( - c_grid_desc_m_n); - } + const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - void* __restrict__ p_shared, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const Block2CTileMap& block_2_ctile_map) - { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -338,7 +463,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + const auto block_2_ctile_map = + Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)}; // divide block work by [M, N] const auto block_work_idx = @@ -440,6 +570,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, FloatABAdjusted, + FloatABAdjusted, FloatAcc, decltype(a_block_desc_k0_m_k1), decltype(b_block_desc_k0_n_k1), @@ -467,6 +598,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); // gridwise GEMM pipeline + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, @@ -565,4 +697,346 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 } }; +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext + : GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 +{ + using Parent = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + using typename Parent::GridwiseGemmPipe; + using typename Parent::Problem; + + using Parent::I1; + + using Parent::K1; + + __device__ static auto + MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t K0, index_t StrideA) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock; + const auto KPad = K0Pad * K1Value; + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + __device__ static auto + MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t K0, index_t StrideB) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock; + const auto KPad = K0Pad * K1Value; + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + return transform_tensor_descriptor(c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock); + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index 55f465a037f438287d176cb131b3ae96c0d5138d..19fbee727f77ae8b6b6f90a5a6f520accde95332 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -43,7 +43,7 @@ __global__ void const CBlockClusterAdaptor c_block_cluster_adaptor) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index f0ce2e3bdbc8ea835a7209c067340cc62a278d3f..d969b3f65703fcd703bdc5342e05013c1c06736f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -8,85 +8,64 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" namespace ck { template + typename CElementwiseOperation> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v2r4r2(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const CBlockClusterAdaptor c_block_cluster_adaptor) + kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - static_cast(p_shared_block), - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - c_block_cluster_adaptor); + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + + __shared__ uint8_t p_shared[shared_size]; + + GridwiseGemm::template Run( + karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_b_k0_m_k1_grid_desc; - ignore = b_b_k0_n_k1_grid_desc; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = karg; + ignore = b2c_map; ignore = a_element_op; ignore = b_element_op; ignore = c_element_op; - ignore = c_block_cluster_adaptor; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } template + typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopScheduler LoopSched = make_default_loop_scheduler(), + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeType = FloatC> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { static constexpr auto I0 = Number<0>{}; @@ -127,10 +109,229 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 static constexpr auto I7 = Number<7>{}; // K1 should be Number<...> - static constexpr auto K1 = Number{}; + static constexpr auto K1 = Number{}; + static constexpr auto M01 = 1; + static constexpr auto N01 = 1; + + static constexpr auto gemm_padder = + tensor_operation::device::GemmPadder{ + MPerBlock, NPerBlock, K1* K0PerBlock}; using ThisThreadBlock = ThisThreadBlock; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + struct Argument : public ck::tensor_operation::device::BaseArgument + { + const FloatA* p_a_grid; + const FloatB* p_b_grid; + FloatC* p_c_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t K0; + index_t k_batch; + + Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0_, + index_t k_batch_) + : p_a_grid(p_a_grid_), + p_b_grid(p_b_grid_), + p_c_grid(p_c_grid_), + M(M_), + N(N_), + K(K_), + StrideA(StrideA_), + StrideB(StrideB_), + StrideC(StrideC_), + MPadded(MPadded_), + NPadded(NPadded_), + KPadded(KPadded_), + K0(K0_), + k_batch(k_batch_) + { + } + + void Print() const + { + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "K0:" << K0 << ", " + << "KB:" << k_batch << "}" << std::endl; + } + }; + + __host__ __device__ static auto CalculateGridSize(const Argument& karg) + { + return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock), + math::integer_divide_ceil(karg.M, MPerBlock), + karg.k_batch); + } + + // prefer this to be called on host + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1) + { + // k_batch * k0 * k0_per_block * k1 + auto K_t = K_Batch * K0PerBlock * K1; + return (K + K_t - 1) / K_t * K0PerBlock; + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K0 = CalculateK0(K, K_Batch); + return K_Batch * K0 * K1; + } + + __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, + index_t MPad, + index_t K, + index_t StrideA, + index_t KBatch, + index_t K0, + index_t KPad) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, + index_t NPad, + index_t N, + index_t StrideB, + index_t KBatch, + index_t K0, + index_t KPad) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n); + } + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { constexpr auto max_lds_align = K1; @@ -175,58 +376,165 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 constexpr auto c_block_size = GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); - return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), + return math::max((a_block_space_size + b_block_space_size) * sizeof(ComputeType), c_block_size * sizeof(FloatC)); } - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - template - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - const Block2CTileMap& block_2_ctile_map) + __host__ __device__ static constexpr bool CheckValidity(const Argument& karg) { - static_assert(is_known_at_compile_time>::value, - "wrong! K1 need to be known at compile-time"); - - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && - (NPerBlock % (NRepeat * NPerXDL)) == 0, - "Invalid tuning param!"); - - const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); - const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); - const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); - const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); - - if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && - K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && - K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && - KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) - return false; + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) - return false; +#endif // DEBUG_LOG + return false; + } + } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } - if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { +#if DEBUG_LOG + std::cout + << "Arg N (" << karg.N + << ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL (" + << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { +#if DEBUG_LOG + std::cout + << "Arg M (" << karg.M + << ") value is not a multiple of CBlockTransferScalarPerVector_NWaveNPerXDL (" + << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + const auto num_k_loop = karg.K0 / K0PerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { +#if DEBUG_LOG + std::cout << "The number of k loops (" << num_k_loop + << ") value is not supported by GridwiseGemm Pipeline." + << " K0: " << karg.K0 << ", K0PerBlock: " << K0PerBlock << " " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; +#endif // DEBUG_LOG return false; } - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + __host__ __device__ static auto GetKPad(index_t K, index_t KBatch) { - const bool has_main_k0_block_loop = K0 > K0PerBlock; + const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0 * K1; + return KPad; + } - return has_main_k0_block_loop; + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const index_t num_loop = K0 / K0PerBlock; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } + template __host__ __device__ static constexpr auto - MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc) + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc) { const auto M = c_m_n_grid_desc.GetLength(I0); const auto N = c_m_n_grid_desc.GetLength(I1); @@ -243,10 +551,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 } // return block_id to C matrix tile idx (m0, n0) mapping + template __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( - const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) + const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) { - return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( c_m_n_grid_desc, 8, KBatch); } @@ -263,24 +572,37 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 Number{})); } - using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); - using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + // return block_id to C matrix tile idx (m0, n0, k_split) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap() + { + return BlockToCTileMap_3DGrid_KSplit(); + } + + using CGridDesc_M_N = remove_cvref_t; + using DefaultBlock2CTileMap = remove_cvref_t; - template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, + template + __device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block, - const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const CBlockClusterAdaptor& c_block_cluster_adaptor) + const Block2CTileMap& block_2_ctile_map, + const AElementwiseOperation a_element_op = AElementwiseOperation{}, + const BElementwiseOperation b_element_op = BElementwiseOperation{}, + const CElementwiseOperation c_element_op = CElementwiseOperation{}) { + const FloatA* p_a_grid = karg.p_a_grid; + const FloatB* p_b_grid = karg.p_b_grid; + FloatC* p_c_grid = karg.p_c_grid; + const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded); + const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1( + karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -288,28 +610,28 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); - - // divide block work by [M, N] + // divide block work by [KBatch, M, N] const auto block_work_idx = - c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - if(!c_block_cluster_adaptor.ValidCTileIndex( - make_tuple(block_work_idx[I1], block_work_idx[I2]), + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) { return; } - const index_t k_batch_id = block_work_idx[I0]; + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]); + const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); // lds max alignment constexpr auto max_lds_align = K1; @@ -387,8 +709,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 Sequence<1, K0PerBlock, MPerBlock, K1>, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, + FloatA, + ComputeType, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), ABlockTransferSrcAccessOrder, @@ -417,8 +739,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 Sequence<1, K0PerBlock, NPerBlock, K1>, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, + FloatB, + ComputeType, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), BBlockTransferSrcAccessOrder, @@ -446,17 +768,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 // register // sanity check - auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeType, // ComputeType A + ComputeType, // ComputeType B + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + K1, + LoopSched>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -464,8 +788,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 constexpr auto a_block_space_size = math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - FloatAB* p_a_block = static_cast(p_shared_block); - FloatAB* p_b_block = static_cast(p_shared_block) + a_block_space_size; + ComputeType* p_a_block = static_cast(p_shared_block); + ComputeType* p_b_block = static_cast(p_shared_block) + a_block_space_size; constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); @@ -475,51 +799,28 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 auto b_block_buf = make_dynamic_buffer( p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); - // preload data into LDS - { - a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); - - a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); - } - - // Initialize C - c_thread_buf.Clear(); - - // main body - if constexpr(HasMainKBlockLoop) - { - index_t k0_block_data_begin = 0; - - do - { - a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step); - b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step); - - a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); - - block_sync_lds(); - - b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - - block_sync_lds(); - - a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); - - k0_block_data_begin += K0PerBlock; - } while(k0_block_data_begin < (K0 - K0PerBlock)); - } - - // tail - { - block_sync_lds(); - - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - } + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) / + (K0PerBlock * K1)); + + const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + gridwise_gemm_pipeline.template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); // output: register to global memory { @@ -648,7 +949,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 {c_block_desc_mblock_mperblock_nblock_nperblock, make_multi_index(0, 0, 0, 0), c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + make_multi_index(block_m_id, 0, block_n_id, 0), c_element_op}; constexpr auto mxdlperwave_forward_step = @@ -717,6 +1018,30 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 }); } } + + static std::string GetTypeString() + { + auto str = std::stringstream(); + + // clang-format off + str << "GemmXdlSplitKCShuffle_" + << getGemmSpecializationString(GemmSpec) << "_" + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << "_" + << "B" << BlockSize << "_" + << "Vec" << ABlockTransferSrcScalarPerVector << "x" + << BBlockTransferSrcScalarPerVector << "x" + << CBlockTransferScalarPerVector_NWaveNPerXDL << "_" + << MPerBlock << "x" + << NPerBlock << "x" + << K0PerBlock << "x" + << K1 ; + // clang-format on + + return str.str(); + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 8259927fecefa8bacd416289530c5b359b9aa5a4..b766b70a6729434b8801ef1a289f3524ad8f4adb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -47,7 +47,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -139,8 +139,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { @@ -315,8 +315,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 c_grid_desc_m_n); } using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = - remove_cvref_t; using DefaultBlock2CTileMap = @@ -451,6 +451,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1, // typename ThreadClusterArrangeOrder, FloatCShuffle, // typename SrcData, FloatC, // typename DstData, - decltype( - c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), - decltype( - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, 5, // index_t VectorDim, CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp index 5d5fdae170b17970b6105406b8e19889dd766549..fbe4dd409d616b8f0aa875eb7e1afde9a76efca5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -50,7 +50,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -142,8 +142,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { @@ -323,13 +323,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 } using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = - remove_cvref_t; using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = - remove_cvref_t; using DefaultBlock2CTileMap = @@ -471,6 +471,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1, // typename DimAccessOrder, 5, // index_t VectorDim, CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index dc83f8e98493d209104fd4cbdd7f1254b2de72a2..1dc8d31efe6a8b6a885f63a34b6b93643bb4337f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -54,7 +54,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__)) + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -151,8 +151,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t())>; + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { @@ -331,18 +331,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 c_grid_desc_m_n); } using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = - remove_cvref_t; using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = - remove_cvref_t; using C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = - remove_cvref_t; using DefaultBlock2CTileMap = @@ -489,6 +489,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1, // typename DimAccessOrder, 5, // index_t VectorDim, CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp index de1ae915920503a086d2cf939ae6e41d27a28838..61d0f9e0d5573c774bbb1dcba707e8c2e775f075 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c41eef8c45fa264a85237c116d223db8e42c48b2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc, + const InDataType* __restrict__ p_in_global, + const IndexDataType* __restrict__ p_indices_global, + OutDataType* __restrict__ p_out_global, + const ElementwiseOperation elementwise_op) +{ + GridwisePutElementwise1dFunctor::Run( + in_grid_1d_desc, p_in_global, p_indices_global, p_out_global, elementwise_op); +} + +// output[indices] = input +template +struct GridwisePutElement_1D +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + __device__ static void Run(const InGrid1dDesc& in_grid_1d_desc, + const InDataType* __restrict__ p_in_global, + const IndexDataType* __restrict__ p_indices_global, + OutDataType* __restrict__ p_out_global, + const ElementwiseOperation& elementwise_op) + { + // Global Memory + const auto in_global_buf = make_dynamic_buffer( + p_in_global, in_grid_1d_desc.GetElementSpaceSize()); + + const auto indices_global_buf = + make_dynamic_buffer(p_indices_global, + in_grid_1d_desc.GetElementSpaceSize(), + NumericLimits::Lowest()); + + // VGPR + StaticBuffer in_thread_buf; + StaticBuffer indices_thread_buf; + + // Thread id, Block id and index + const index_t thread_global_id = get_thread_global_1d_id(); + const auto thread_global_offset = make_multi_index(thread_global_id * InVectorSize); + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto M = in_grid_1d_desc.GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * InVectorSize; + const auto loop_step_index = make_multi_index(loop_step); + + auto in_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + InVectorSize, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{in_grid_1d_desc, thread_global_offset}; + + auto indices_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + InVectorSize, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{in_grid_1d_desc, thread_global_offset}; + + index_t num_iter = M / loop_step; + do + { + in_global_load.Run(in_grid_1d_desc, + in_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + in_thread_buf); + + in_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index); + + static_for<0, InVectorSize, 1>{}( + [&](auto iM) { elementwise_op(in_thread_buf(iM), in_thread_buf[iM]); }); + + indices_global_load.Run(in_grid_1d_desc, + indices_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + indices_thread_buf); + + indices_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index); + + static_for<0, InVectorSize, 1>{}([&](auto iM) { + if(indices_thread_buf[iM] >= 0) + { + if constexpr(MemOp == InMemoryDataOperationEnum::Set) + { + // User should guarantee each index in p_indices_global is different + *(p_out_global + indices_thread_buf[iM]) = + ck::type_convert(in_thread_buf[iM]); + } + else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicAdd) + { + atomic_add(p_out_global + indices_thread_buf[iM], + ck::type_convert(in_thread_buf[iM])); + } + else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicMax) + { + atomic_max(p_out_global + indices_thread_buf[iM], + ck::type_convert(in_thread_buf[iM])); + } + else if constexpr(MemOp == InMemoryDataOperationEnum::Add) + { + // User should guarantee each index in p_indices_global is different + *(p_out_global + indices_thread_buf[iM]) += + ck::type_convert(in_thread_buf[iM]); + } + else + { + static_assert(MemOp == InMemoryDataOperationEnum::Set || + MemOp == InMemoryDataOperationEnum::AtomicAdd || + MemOp == InMemoryDataOperationEnum::AtomicMax || + MemOp == InMemoryDataOperationEnum::Add); + } + } + }); + + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp index 901e7aee98a7b6ab1c37b7dd4dc36f796632fef6..41352fabeb361b017949738bda31649a3ce0dd40 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp index 88c7b6acfeb9c43ec0e8b5a5ab557f834430fb40..0ad36b418a42bff56eb87ba35bb5150808e9cac4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp index 0344e68305b1a951cfc41a53666090bd1e0113b6..5f56ac6fc4664be1685a950dc6e242b3e256e178 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp index ff2511fa6e61eda6c4cca976d0372cf98840ee5c..287b4e5421682eb3596f0d47b46bf769d32377b0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -78,8 +78,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple( Number{}, Number{}))); - using ThreadwiseWolfordDescReduce = decltype( - make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + using ThreadwiseWolfordDescReduce = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}))); using ThreadwiseWelford = ThreadwiseWelford; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f77ffff350b5c3f28fc79db9af06202ac6d8d584 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_tensor_rearrange(const InputGridDesc in_grid_desc, + const InputDataType* __restrict__ p_in_global, + const OutputGridDesc out_grid_desc, + OutputDataType* __restrict__ p_out_global, + const index_t batch_count, + const Block2ETileMap block_2_tile_map, + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \ + defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__)) + GridwiseTensorRearrangeKernel::Run(in_grid_desc, + p_in_global, + out_grid_desc, + p_out_global, + batch_count, + block_2_tile_map, + compute_ptr_offset_of_batch); +#else + ignore = in_grid_desc; + ignore = p_in_global; + ignore = out_grid_desc; + ignore = p_out_global; + ignore = block_2_tile_map; +#endif +} + +template +struct GridwiseTensorRearrange +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + using ThisThreadBlock = ThisThreadBlock; + + __device__ static void Run(const InputGridDesc& in_grid_desc, + const InputDataType* __restrict__ p_in_global, + const OutputGridDesc& out_grid_desc, + OutputDataType* __restrict__ p_out_global, + const index_t batch_count, + const Block2ETileMap& block_2_tile_map, + const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch) + { + const auto block_work_idx = + block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t k_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock); + + auto copy_global_to_global = + ThreadGroupTensorSliceTransfer_v7, + Tuple, + decltype(tie(in_grid_desc)), + decltype(tie(out_grid_desc)), + tensor_operation::element_wise::PassThrough, + Sequence(DstInMemOp)>, + Sequence, + ThreadClusterLengths, + Sequence<0, 1>, + Sequence<0, 1>, + I1, + ScalarPerVector, + Sequence, + Sequence>{ + in_grid_desc, + make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)), + out_grid_desc, + make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)), + tensor_operation::element_wise::PassThrough{}}; + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + // Global Memory + const index_t a_batch_offset = + __builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const index_t c_batch_offset = + __builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + const auto in_global_buf = make_dynamic_buffer( + p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize()); + auto out_global_buf = make_dynamic_buffer( + p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize()); + + copy_global_to_global.Run( + tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf)); + } + + __host__ static constexpr bool CheckValidity(const InputGridDesc& in_grid_desc, + const OutputGridDesc& out_grid_desc) + { + if(in_grid_desc.GetLength(I0) % MPerBlock != 0 || + in_grid_desc.GetLength(I1) % KPerBlock != 0) + return false; + if(out_grid_desc.GetLength(I0) % MPerBlock != 0 || + out_grid_desc.GetLength(I1) % KPerBlock != 0) + return false; + return true; + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp similarity index 80% rename from include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp rename to include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp index 792ffabcb90c71b4620915572c29091b67c8c4d6..1bee1b93f9763efd6d5a92fddf65fbef8c62cb5c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -18,9 +18,11 @@ template struct GridwiseNormalizationNaiveVariance_mk_to_mk { @@ -45,6 +48,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + static_assert(XSrcVectorSize == YDstVectorSize); static_assert(XSrcVectorSize == GammaSrcVectorSize); static_assert(XSrcVectorSize == BetaSrcVectorSize); @@ -66,6 +73,10 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); + using ThreadBufferLengths_M = Sequence; + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}))); using ThreadReduceDstDesc_M = @@ -84,6 +95,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk reduce::Add, true>; + using PassThroughOp = tensor_operation::element_wise::PassThrough; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -98,12 +111,16 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk const GridDesc_M_K& gamma_grid_desc_m_k, const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k, + const GridDesc_M& save_mean_grid_desc_m, + const GridDesc_M& save_inv_std_grid_desc_m, index_t num_k_block_tile_iteration, ComputeDataType epsilon, const XDataType* const __restrict__ p_x_global, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op) { // LDS @@ -115,6 +132,12 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk auto y_global_val_buf = make_dynamic_buffer( p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + auto save_mean_global_val_buf = make_dynamic_buffer( + p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize()); + + auto save_inv_std_global_val_buf = make_dynamic_buffer( + p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize()); + auto x_thread_buf = generate_tuple( [&](auto) { return StaticBuffer& var_thread_buf = mean_square_thread_buf; + StaticBuffer& + inv_std_thread_buf = mean_square_thread_buf; const index_t thread_local_id = get_thread_local_1d_id(); const index_t block_global_id = get_block_1d_id(); @@ -228,6 +253,42 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk thread_k_cluster_id * YDstVectorSize), y_elementwise_op); + auto threadwise_mean_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_mean_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_inv_std_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_inv_std_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); @@ -243,7 +304,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // E(x), E[x^2], var(x) // FIXME: Should not hack the transform from deviceOP - int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; + ComputeDataType reduce_length = type_convert( + x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]); static_for<0, MThreadSliceSize, 1>{}([&](auto I) { mean_thread_buf(I) = reduce::Add::template GetIdentityValue(); @@ -302,10 +364,34 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // var(x) = E[x^2] - E[x]^2 var_thread_buf(I) = mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); + + inv_std_thread_buf(I) = type_convert(1.0f) / + ck::math::sqrt(var_thread_buf(I) + epsilon); }); + // save mean and inverse std for backward (optional) + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + // normalization static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -314,7 +400,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma & beta y_thread_buf(iK0)(Number{}) = @@ -404,8 +490,30 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // var(x) = E[x^2] - E[x]^2 var_thread_buf(I) = mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); + + inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon); }); + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; @@ -437,7 +545,6 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -446,7 +553,7 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma y_thread_buf(iK0)(Number{}) = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp similarity index 73% rename from include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp rename to include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp index 37795fa5694ae402af3225c11ffdbe16a6bc086f..8157b4fbc39cc1ce9a3958ce2036e094b0d7bdcd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp @@ -1,10 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp" namespace ck { template -__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, - const GridDesc_M_K gamma_grid_desc_m_k, - const GridDesc_M_K beta_grid_desc_m_k, - const GridDesc_M_K y_grid_desc_m_k, - index_t num_k_block_tile_iteration, - ComputeDataType epsilon, - const XDataType* const __restrict__ p_x_global, - const GammaDataType* const __restrict__ p_gamma_global, - const BetaDataType* const __restrict__ p_beta_global, - YDataType* const __restrict__ p_y_global, - const YElementwiseOperation y_elementwise_op) + typename GridDesc_M_K, + typename GridDesc_M> +__global__ void +kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K gamma_grid_desc_m_k, + const GridDesc_M_K beta_grid_desc_m_k, + const GridDesc_M_K y_grid_desc_m_k, + const GridDesc_M save_mean_grid_desc_m, + const GridDesc_M save_inv_std_grid_desc_m, + index_t num_k_block_tile_iteration, + ComputeDataType epsilon, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, + const YElementwiseOperation y_elementwise_op) { GridwiseReduction::Run(x_grid_desc_m_k, gamma_grid_desc_m_k, beta_grid_desc_m_k, y_grid_desc_m_k, + save_mean_grid_desc_m, + save_inv_std_grid_desc_m, num_k_block_tile_iteration, epsilon, p_x_global, p_gamma_global, p_beta_global, p_y_global, + p_save_mean_global, + p_save_inv_std_global, y_elementwise_op); }; @@ -44,9 +55,11 @@ template auto NormalizationKernelSelector(bool isSweepOnce) { @@ -68,9 +82,11 @@ auto NormalizationKernelSelector(bool isSweepOnce) GammaDataType, BetaDataType, YDataType, + SaveMeanInvStdDataType, ComputeDataType, YElementwiseOperation, GridDesc_M_K, + GridDesc_M, BlockSize, MThreadClusterSize, KThreadClusterSize, @@ -84,15 +100,18 @@ auto NormalizationKernelSelector(bool isSweepOnce) BetaSrcVectorSize, YDstVectorDim, YDstVectorSize, + SaveMeanInvStdDstVectorSize, false>; using GridwiseNormalizationSweepOnceNaive = GridwiseNormalizationNaiveVariance_mk_to_mk; using GridwiseNormalizationGenericWelford = GridwiseNormalizationWelfordVariance_mk_to_mk; using GridwiseNormalizationSweepOnceWelford = GridwiseNormalizationWelfordVariance_mk_to_mk; if constexpr(UseWelford) @@ -159,17 +185,21 @@ auto NormalizationKernelSelector(bool isSweepOnce) GammaDataType, BetaDataType, YDataType, + SaveMeanInvStdDataType, ComputeDataType, YElementwiseOperation, - GridDesc_M_K> + GridDesc_M_K, + GridDesc_M> : kernel_normalization; + GridDesc_M_K, + GridDesc_M>; } else { @@ -178,17 +208,21 @@ auto NormalizationKernelSelector(bool isSweepOnce) GammaDataType, BetaDataType, YDataType, + SaveMeanInvStdDataType, ComputeDataType, YElementwiseOperation, - GridDesc_M_K> + GridDesc_M_K, + GridDesc_M> : kernel_normalization; + GridDesc_M_K, + GridDesc_M>; } } diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp new file mode 100644 index 0000000000000000000000000000000000000000..80e9a84f9681f6c630db441a7f2de33826f19cf9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseNormalizationSplitK1st +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + using ThreadBufferLengths_M_1 = Sequence; + static constexpr auto thread_buffer_desc_m_1 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1)); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; + + static constexpr auto ThreadBufferNumber = Number{}; + + __device__ static int + GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id) + { + bool is_rightmost_block = block_k_cluster_id == kGridSize - 1; + + if(is_rightmost_block) + { + int left_kPerBlock = math::integer_divide_ceil(k, kGridSize); + int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1); + int kPerThread = kRightmostBlock < K_BlockTileSize + ? 0 + : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize); + int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize; + + if(kPerBlockTail > 0) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + int thread_max_len = + (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i; + int delta = thread_max_len - kPerBlockTail; + delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize); + kPerThread += XSrcVectorSize - delta; + }); + } + + return kPerThread; + } + else + { + int kPerBlock = math::integer_divide_ceil(k, kGridSize); + return KThreadSliceSize * (kPerBlock / K_BlockTileSize); + } + } + + // Calculate mean and variance by welford along k dimension + __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k, + const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock, + index_t num_k_block_tile_iteration, + const XDataType* const __restrict__ p_x_global, + MeanVarDataType* const p_mean_global, + MeanVarDataType* const p_variance_global, + int32_t* const p_welford_count_global) + { + auto x_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + StaticBuffer + mean_thread_buf; + StaticBuffer + var_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const index_t k_grid_size = mean_var_grid_desc_m_kblock.GetLength(I1); + const index_t block_m_cluster_id = block_global_id / k_grid_size; + const index_t block_k_cluster_id = block_global_id % k_grid_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index( + block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * reduceSizePerBlock + thread_k_cluster_id * XSrcVectorSize)); + + auto mean_var_count_store_index = make_multi_index( + block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id); + + auto threadwise_welford_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m_kblock, mean_var_count_store_index, PassThroughOp{}); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); + + auto var_global_val_buf = make_dynamic_buffer( + p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); + + auto threadwise_welford = ThreadwiseWelford(); + int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; + threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k.GetLength(I1), + kRaw, + k_grid_size, + block_k_cluster_id, + thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t k = 0; k < num_k_block_tile_iteration; ++k) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf); + }); + } + + int welford_count = 0; + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + + // The value of count is same for all I + if constexpr(I == MThreadSliceSize - 1) + welford_count = count; + }); + + if(thread_k_cluster_id == 0) + { + threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + mean_thread_buf, + mean_var_grid_desc_m_kblock, + mean_global_val_buf); + + threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + var_thread_buf, + mean_var_grid_desc_m_kblock, + var_global_val_buf); + + if(block_m_cluster_id == 0 && thread_m_cluster_id == 0) + p_welford_count_global[block_k_cluster_id] = welford_count; + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9e380f963801460e93684ba4a5cf3244c3fef1cb --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp @@ -0,0 +1,499 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseNormalizationSplitK2nd +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + + static_assert(XSrcVectorSize == YDstVectorSize); + static_assert(XSrcVectorSize == GammaSrcVectorSize); + static_assert(XSrcVectorSize == BetaSrcVectorSize); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + using ThreadBufferLengths_M = Sequence; + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using ThreadBufferLengths_M_1 = Sequence; + static constexpr auto thread_buffer_desc_m_1 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1)); + + using ThreadWelfordSrcDesc_M_1 = decltype(thread_buffer_desc_m_1); + using ThreadWelfordDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelfordMerge; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; + + static constexpr auto ThreadBufferNumber = Number{}; + + __device__ static void Run(const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock, + const CountGridDesc_M_KBlock& count_grid_desc_m_kblock, + const XYGammaBetaGridDesc_M_K& x_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k, + const SaveMeanInvStdGridDesc_M& save_mean_grid_desc_m, + const SaveMeanInvStdGridDesc_M& save_inv_std_grid_desc_m, + index_t num_k_mean_var_count_iteration, + index_t num_k_block_tile_iteration, + index_t k_grid_size, + ComputeDataType epsilon, + const MeanVarDataType* const p_mean_global, + const MeanVarDataType* const p_variance_global, + const int32_t* const p_welford_count_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, + const YElementwiseOperation y_elementwise_op) + { + // Thread/Block id + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t block_m_cluster_id = block_global_id / k_grid_size; + const index_t block_k_cluster_id = block_global_id % k_grid_size; + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + // Global Memory + const auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); + + const auto var_global_val_buf = make_dynamic_buffer( + p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); + + const auto welford_count_global_val_buf = make_dynamic_buffer( + p_welford_count_global, count_grid_desc_m_kblock.GetElementSpaceSize()); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto beta_global_val_buf = make_dynamic_buffer( + p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); + + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + auto save_mean_global_val_buf = make_dynamic_buffer( + p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize()); + + auto save_inv_std_global_val_buf = make_dynamic_buffer( + p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize()); + + // VGPR + StaticBuffer + in_mean_thread_buf; + StaticBuffer + in_var_thread_buf; + StaticBuffer + in_welford_count_thread_buf; + StaticBuffer + mean_thread_buf; + StaticBuffer + var_thread_buf; + StaticBuffer + welford_count_thread_buf; + auto& inv_std_thread_buf = var_thread_buf; + + auto x_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto gamma_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto& beta_thread_buf = gamma_thread_buf; + auto& y_thread_buf = x_thread_buf; + + // IO + auto threadwise_mean_var_load_m_kblock = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + mean_var_grid_desc_m_kblock, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id)); + + auto threadwise_count_load_m_kblock = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + count_grid_desc_m_kblock, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id)); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration + + thread_k_cluster_id * XSrcVectorSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration + + thread_k_cluster_id * GammaSrcVectorSize)); + + auto threadwise_beta_load = + ThreadwiseTensorSliceTransfer_v2( + beta_grid_desc_m_k, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration + + thread_k_cluster_id * BetaSrcVectorSize)); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration + + thread_k_cluster_id * YDstVectorSize), + y_elementwise_op); + + auto threadwise_mean_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_mean_grid_desc_m, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_inv_std_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_inv_std_grid_desc_m, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + // step1: Merge mean and variance + constexpr auto mean_var_count_thread_copy_step_I0_k = + make_multi_index(I0, KThreadClusterSize); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + welford_count_thread_buf(I) = 0; + }); + + for(index_t k = 0; k < num_k_mean_var_count_iteration; ++k) + { + threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock, + mean_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_mean_thread_buf); + + threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock, + var_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_var_thread_buf); + + threadwise_count_load_m_kblock.Run(count_grid_desc_m_kblock, + welford_count_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_count_thread_buf); + + ThreadwiseWelford::Run(in_mean_thread_buf, + in_var_thread_buf, + in_welford_count_thread_buf, + mean_thread_buf, + var_thread_buf, + welford_count_thread_buf); + + threadwise_mean_var_load_m_kblock.MoveSrcSliceWindow( + mean_var_grid_desc_m_kblock, mean_var_count_thread_copy_step_I0_k); + threadwise_count_load_m_kblock.MoveSrcSliceWindow(count_grid_desc_m_kblock, + mean_var_count_thread_copy_step_I0_k); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseWelford::Run( + mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I)); + + inv_std_thread_buf(I) = + type_convert(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon); + }); + + // step2: save mean and inverse std for backward (optional) + if(block_k_cluster_id == 0 && thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + // step3: normalization + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); + + for(index_t k = 0; k < num_k_block_tile_iteration; ++k) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); + + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + } // end for (normalization) + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp similarity index 79% rename from include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp rename to include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp index 3a7ae459e5f6e8d58edc49ae53c4309030e911f9..15b412fba4cc0dd3d33e6b9a4f663c94a2053bda 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,9 +16,11 @@ template struct GridwiseNormalizationWelfordVariance_mk_to_mk { @@ -43,6 +46,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + static_assert(XSrcVectorSize == YDstVectorSize); static_assert(XSrcVectorSize == GammaSrcVectorSize); static_assert(XSrcVectorSize == BetaSrcVectorSize); @@ -64,6 +71,10 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); + using ThreadBufferLengths_M = Sequence; + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}))); using ThreadReduceDstDesc_M = @@ -77,6 +88,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk ThreadClusterLengths_M_K, ThreadClusterArrangeOrder>; + using PassThroughOp = tensor_operation::element_wise::PassThrough; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -114,17 +127,18 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk const GridDesc_M_K& gamma_grid_desc_m_k, const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k, + const GridDesc_M& save_mean_grid_desc_m, + const GridDesc_M& save_inv_std_grid_desc_m, index_t num_k_block_tile_iteration, ComputeDataType epsilon, const XDataType* const __restrict__ p_x_global, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op) { - auto y_global_val_buf = make_dynamic_buffer( - p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); - auto x_thread_buf = generate_tuple( [&](auto) { return StaticBuffer var_thread_buf; + auto& inv_std_thread_buf = var_thread_buf; const index_t thread_local_id = get_thread_local_1d_id(); const index_t block_global_id = get_block_1d_id(); @@ -226,6 +241,42 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk thread_k_cluster_id * YDstVectorSize), y_elementwise_op); + auto threadwise_mean_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_mean_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_inv_std_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_inv_std_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); @@ -239,6 +290,15 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk const auto beta_global_val_buf = make_dynamic_buffer( p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + auto save_mean_global_val_buf = make_dynamic_buffer( + p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize()); + + auto save_inv_std_global_val_buf = make_dynamic_buffer( + p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize()); + auto threadwise_welford = ThreadwiseWelford(); threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id); @@ -279,10 +339,33 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk int count = threadwise_welford.cur_count_; BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + inv_std_thread_buf(I) = type_convert(1.0f) / + ck::math::sqrt(var_thread_buf(I) + epsilon); }); + // save mean and inverse std for backward (optional) + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + // normalization static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -291,7 +374,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma & beta y_thread_buf(iK0)(Number{}) = @@ -360,8 +443,29 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk int count = threadwise_welford.cur_count_; BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon); }); + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; @@ -393,7 +497,6 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = @@ -402,7 +505,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk // normalize y_thread_buf(iK0)(Number{}) = (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * - divisor; + inv_std_thread_buf(iM); // gamma y_thread_buf(iK0)(Number{}) = diff --git a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp index 188c62d93b0d60317ac139115b79f496612d2ea6..c6eecc067d9f77fa488e7a922fe1fb711c0c0d0f 100644 --- a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp index 94cdfe01087bc4005d45f11992a6a8f94d42691c..44730d551c5c8c9fb2408502ece0d45c6f37a221 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp index e045e3b545a9897f5df9df1948123e86424d7c7b..e97aa433a6f3b9d54fe76ed26b5b0f6bd439ae84 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP #define CK_THREADWISE_GEMM_DLOPS_V3_HPP diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp index 0a1197a1630c12c53afa9513f10fcfa907ba82d3..6774a35bcb872466d272da0bccd618709d3fb29b 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 16484ddcc5a89792c1331a48516e8317b36f2633..27742140795c0b3c37b908d232b17a67482d893f 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3 constexpr index_t src_offset = src_desc.CalculateOffset( src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_buf[Number{}]); - // apply type convert - dst_vector.template AsType()(i) = type_convert(v); + dst_vector.template AsType()(i) = v; }); const bool is_dst_valid = @@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic constexpr index_t dst_offset = dst_desc.CalculateOffset( dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_buf[Number{}]); // apply type convert - dst_buf(Number{}) = type_convert(v); + dst_buf(Number{}) = v; }); }); } diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index cba06f8e87520fea3f15640ca2536d5d8d0ca58c..0b2300fbe5fc3a2e15d8a66b81c98c731e45b7d2 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -6,8 +6,10 @@ #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor/static_tensor.hpp" +#include "ck/utility/is_detected.hpp" namespace ck { @@ -128,6 +130,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0, + "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; constexpr auto ordered_src_access_lengths = @@ -207,19 +212,44 @@ struct ThreadwiseTensorSliceTransfer_v3r1 auto src_vector_container = src_vector_type{ src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; - // apply SrcElementwiseOperation on src_vector_container - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - SrcData src_v; + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + dst_vector_type op_r_v; - src_element_op_(src_v, src_vector_container.template AsType()[i]); + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); - src_vector_container.template AsType()(i) = src_v; + using src_elem_op_vec_t = typename vector_type::type; + using dst_elem_op_vec_t = typename vector_type::type; + + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) { + // apply the src elementwise op and convert to DstData under the hood if needed + src_element_op_(op_r_v.template AsType()(idx), + src_vector_container.template AsType()[idx]); }); // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) - .template SetAsType( - src_data_idx_seq, src_vector_container.template AsType()[I0]); + .template SetAsType(src_data_idx_seq, + op_r_v.template AsType()[I0]); constexpr auto move_on_dim = [&]() constexpr { @@ -272,19 +302,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1 { #if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE static_ford{}([&](auto idx) { - // convert from SrcData to DstData here - dst_thread_scratch_(idx) = - type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); + dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); #else // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ // TODO make this logic more generic for more sub-dword datatype if constexpr(SrcVectorDim != DstVectorDim && - ((is_same>::value && - is_same>::value && + ((is_same>::value && SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || - (is_same>::value && - is_same>::value && + (is_same>::value && SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) { // each transpose does @@ -318,8 +344,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); - // TODO type_convert is not used yet!!!!! - using src_vector_t = vector_type_maker_t; + using src_vector_t = vector_type_maker_t; using dst_vector_t = vector_type_maker_t; // get DstScalarPerVector # of read-only references to src vectors from @@ -342,17 +367,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1 Number{}); // do data transpose - // TODO type_convert is not used yet!!!!! - transpose_vectors{}( + transpose_vectors{}( src_vector_refs, dst_vector_refs); }); } else { static_ford{}([&](auto idx) { - // convert from SrcData to DstData here - dst_thread_scratch_(idx) = - type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); + dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); } #endif @@ -769,11 +791,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; - using SrcThreadScratch = StaticTensorTupleOfVectorBuffer; + using SrcThreadScratch = + StaticTensorTupleOfVectorBuffer; using DstThreadScratch = StaticTensorTupleOfVectorBuffer -struct lambda_scalar_per_access_for_src_and_dst -{ - __host__ __device__ constexpr auto operator()(index_t i) const - { - if(i == SrcVectorDim && i == DstVectorDim) - { - return math::lcm(SrcScalarPerVector, DstScalarPerVector); - } - else if(i == SrcVectorDim) - { - return SrcScalarPerVector; - } - else if(i == DstVectorDim) - { - return DstScalarPerVector; - } - else - { - return 1; - } - } -}; - -} // namespace detail - -// Assume: -// 1. src_desc and dst_desc are not known at compile-time -// 2. SrcBuffer and DstBuffer are DynamicBuffer -// 3. src_slice_origin and dst_slice_origin are not known at compile-time, -// 4. Use thread buffer -template // control whether to move back dst coordinate after each - // RunWrite(), will be fused with MoveDstSliceWindow to - // save addr computation -struct ThreadwiseTensorSliceTransfer_v3r3 -{ - static constexpr index_t nDim = SliceLengths::Size(); - using Index = MultiIndex; - - using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - using Dst0Coord = decltype(make_tensor_coordinate(Dst0Desc{}, Index{})); - using Dst1Coord = decltype(make_tensor_coordinate(Dst1Desc{}, Index{})); - - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - using Dst0CoordStep = decltype(make_tensor_coordinate_step(Dst0Desc{}, Index{})); - using Dst1CoordStep = decltype(make_tensor_coordinate_step(Dst1Desc{}, Index{})); - - __device__ constexpr ThreadwiseTensorSliceTransfer_v3r3( - const SrcDesc& src_desc, - const Index& src_slice_origin, - const SrcElementwiseOperation& src_element_op, - const DstDesc& dst_desc, - const Dst0Desc& dst0_desc, - const Dst1Desc& dst1_desc, - const Index& dst_slice_origin, - const DstElementwiseOperation& dst_element_op) - : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), - dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), - dst0_coord_(make_tensor_coordinate(dst0_desc, dst_slice_origin)), - dst1_coord_(make_tensor_coordinate(dst1_desc, dst_slice_origin)), - src_element_op_(src_element_op), - dst_element_op_(dst_element_op) - { - } - - __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) - { - src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); - } - - __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, - const Dst0Desc& dst0_desc, - const Dst1Desc& dst1_desc, - const Index& dst_slice_origin_idx) - { - dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); - dst0_coord_ = make_tensor_coordinate(dst0_desc, dst_slice_origin_idx); - dst1_coord_ = make_tensor_coordinate(dst1_desc, dst_slice_origin_idx); - } - - template - __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) - { - static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or - SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, - "wrong!"); - - static_assert( - is_same, remove_cvref_t>::value, - "wrong! SrcBuffer and SrcData data type are inconsistent"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - - // make forward steps - const auto src_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto src_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, backward_step_idx); - }, - Number{}); - - // loop over tensor and copy - static_ford{}([&](auto ordered_src_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] - : ordered_src_access_lengths[i] - 1 - - ordered_src_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); - - constexpr auto src_data_idx_seq = generate_sequence_v2( - [&](auto i) { return Number{}; }, Number{}); - - const bool is_src_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); - - using src_vector_type = vector_type_maker_t; - using src_vector_t = typename src_vector_type::type; - - // copy data from src_buf into src_vector_container - auto src_vector_container = src_vector_type{ - src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; - - // apply SrcElementwiseOperation on src_vector_container - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - src_vector_container.template AsType()(i) = - src_element_op_(src_vector_container.template AsType()[i]); - }); - - // copy data from src_vector_container into src_thread_scratch_ - src_thread_scratch_.template SetAsType( - src_data_idx_seq, src_vector_container.template AsType()[I0]); - - constexpr auto move_on_dim = [&]() constexpr - { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - } - (); - - // move src coord - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); - } - } - }); - }); - - // move src coordinate back to slice origin (or not) - if constexpr(SrcResetCoordinateAfterRun) - { - const auto src_reset_step = - make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); - - move_tensor_coordinate(src_desc, src_coord_, src_reset_step); - } - } - - __device__ void TransferDataFromSrcThreadScratchToDstThreadScratch() - { -#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE - static_ford{}([&](auto idx) { - // convert from SrcData to DstData here - dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); - }); -#else - // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ - // TODO make this logic more generic for more sub-dword datatype - if constexpr(SrcVectorDim != DstVectorDim && - is_same>::value && - is_same>::value && - SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) - { - // each transpose does - // DstScalarPerVector # of src vectors in src_thread_scratch_ - // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ - constexpr index_t num_src_vector = Number{}; - constexpr index_t num_dst_vector = Number{}; - - // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose - // TODO: make this logic generic for all scenario - static_assert(SrcVectorDim != DstVectorDim, "wrong"); - - constexpr auto src_scalar_step_in_vector = generate_sequence( - detail::lambda_scalar_step_in_vector{}, Number{}); - - constexpr auto dst_scalar_step_in_vector = generate_sequence( - detail::lambda_scalar_step_in_vector{}, Number{}); - - constexpr auto scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access_for_src_and_dst{}, - Number{}); - - constexpr auto access_lengths = SliceLengths{} / scalar_per_access; - - static_ford{}([&](auto access_idx) { - constexpr auto data_idx = access_idx * scalar_per_access; - - constexpr auto data_idx_seq = generate_sequence_v2( - [&](auto i) { return Number{}; }, Number{}); - - // TODO type_convert is not used yet!!!!! - using src_vector_t = vector_type_maker_t; - using dst_vector_t = vector_type_maker_t; - - // get DstScalarPerVector # of read-only references to src vectors from - // src_thread_scratch_ - const auto src_vector_refs = generate_tie( - [&](auto i) -> const src_vector_t& { - // i increment corresponds to movement in DstVectorDim - return src_thread_scratch_.GetVectorTypeReference( - data_idx_seq + i * dst_scalar_step_in_vector); - }, - Number{}); - - // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ - auto dst_vector_refs = generate_tie( - [&](auto i) -> dst_vector_t& { - // i increment corresponds to movement in SrcVectorDim - return dst_thread_scratch_.GetVectorTypeReference( - data_idx_seq + i * src_scalar_step_in_vector); - }, - Number{}); - - // do data transpose - // TODO type_convert is not used yet!!!!! - transpose_vectors{}( - src_vector_refs, dst_vector_refs); - }); - } - else - { - static_ford{}([&](auto idx) { - // convert from SrcData to DstData here - dst_thread_scratch_(idx) = type_convert(src_thread_scratch_[idx]); - }); - } -#endif - } - - template - __device__ void RunWrite(const DstDesc& dst_desc, - DstBuffer& dst_buf, - const Dst0Desc& dst0_desc, - const Dst0Buffer& dst0_buf, - const Dst1Desc& dst1_desc, - const Dst1Buffer& dst1_buf) - { - // if there is transpose, it's done here - // TODO move this elsewhere - TransferDataFromSrcThreadScratchToDstThreadScratch(); - - static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or - DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, - "wrong!"); - - static_assert( - is_same, remove_cvref_t>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - // src scalar per access on each dim - // TODO: don't use this - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; - - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - - // make forward steps - const auto dst_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst_desc, forward_step_idx); - }, - Number{}); - - // make forward steps: dst0 - // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same - // DstScalarPerVector - // TODO: fix this - const auto dst0_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst0_desc, forward_step_idx); - }, - Number{}); - - // make forward steps: dst1 - // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same - // DstScalarPerVector - // TODO: fix this - const auto dst1_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst1_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto dst_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst_desc, backward_step_idx); - }, - Number{}); - - // make backward steps: dst0 - // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same - // DstScalarPerVector - // TODO: fix this - const auto dst0_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst0_desc, backward_step_idx); - }, - Number{}); - - // make backward steps: dst1 - // WARNING!!!!!!: this logic is only correct if dst/dst0/dst1 can use the same - // DstScalarPerVector - // TODO: fix this - const auto dst1_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst1_desc, backward_step_idx); - }, - Number{}); - - // loop over tensor and copy - static_ford{}([&](auto ordered_dst_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] - : ordered_dst_access_lengths[i] - 1 - - ordered_dst_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); - - constexpr auto dst_data_idx_seq = generate_sequence_v2( - [&](auto i) { return Number{}; }, Number{}); - - const bool is_dst_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); - - using dst_vector_type = vector_type_maker_t; - using dst_vector_t = typename dst_vector_type::type; - - // copy data from dst_thread_scratch_ into dst_vector_container - auto dst_vector_container = dst_vector_type{ - dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; - - // apply DstElementwiseOperation on dst_vector_container - static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - dst_vector_container.template AsType()(i) = - dst_element_op_(dst_vector_container.template AsType()[i]); - }); - - // copy data from dst_vector_container to dst_buf - dst_buf.template Set( - dst_coord_.GetOffset(), - is_dst_valid, - dst_vector_container.template AsType()[I0]); - - constexpr auto move_on_dim = [&]() constexpr - { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - } - (); - - // move dst coord - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); - } - } - }); - }); - - // move dst coordinate back to slice origin (or not) - if constexpr(DstResetCoordinateAfterRun) - { - const auto dst_reset_step = - make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); - - move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); - } - } - - __device__ static constexpr auto GetSrcCoordinateResetStep() - { - constexpr auto I0 = Number<0>{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - // TODO: BUG: should start at 1 - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index after last iteration in RunRead(), if it has not being reset by - // RunRead() - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); - - // - constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); - - return reset_src_data_step_; - }(); - - return reset_src_data_step; - } - - __device__ static constexpr auto GetDstCoordinateResetStep() - { - constexpr auto I0 = Number<0>{}; - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; - - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index after last iteration in RunWrite(), if it has not being reset by - // RunWrite() - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); - - // - constexpr auto reset_dst_data_step = [&]() { - Index reset_dst_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); - - return reset_dst_data_step_; - }(); - - return reset_dst_data_step; - } - - // src_slice_origin_step_idx need to be known at compile-time, for performance reason - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& src_slice_origin_step_idx) - { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); - } - - // src_slice_origin_step_idx need to be known at compile-time, for performance reason - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, - const Index& src_slice_origin_step_idx) - { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); - } - - // dst_slice_origin_step_idx need to be known at compile-time, for performance reason - __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, - const Dst0Desc dst0_desc, - const Dst1Desc dst1_desc, - const Index& dst_slice_origin_step_idx) - { - // if dst coord was not reset by RunWrite(), then need to adjust the step here - const auto adjusted_step_idx = - DstResetCoordinateAfterRun ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); - move_tensor_coordinate(dst0_desc, dst0_coord_, adjusted_step); - move_tensor_coordinate(dst1_desc, dst1_coord_, adjusted_step); - } - - __device__ static constexpr auto GetSrcThreadScratchDescriptor() - { - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(src_access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); - } - - __device__ static constexpr auto GetDstThreadScratchDescriptor() - { - // 1st stage of transforms - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(dst_access_lengths), Number{}); - - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = - generate_tuple([&](auto i) { return Sequence{}; }, Number{}); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); - } - - private: - static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; - static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; - - StaticTensorTupleOfVectorBuffer - src_thread_scratch_; - - StaticTensorTupleOfVectorBuffer - dst_thread_scratch_; - - SrcCoord src_coord_; - DstCoord dst_coord_; - const SrcElementwiseOperation src_element_op_; - const DstElementwiseOperation dst_element_op_; -}; - -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp index 6e8a23930bbb91677ee18bab216af6c45de72e4c..6a6c1f2ac5bfed7f02c5f5895d16002226cff56c 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index f13da341f9b2d36d992b1cd1835b43c89dcabd1f..bd01108b03ca82d4877e19806521efda8da8c6c3 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp index 9c91cd9ca8f86df37fee6dfdf597be1cefe2d683..644877d3931f08ea46b017446140484de8746fa9 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -104,13 +104,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1 // apply pointwise operation static_for<0, ScalarPerVector, 1>{}([&](auto i) { - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_vector_container.template AsType()[i]); // apply type convert - dst_vector_container.template AsType()(i) = type_convert(v); + dst_vector_container.template AsType()(i) = v; }); const bool is_dst_valid = diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..88ed217547088b685a5291026190c74db9305da1 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r1r2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1r2( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // loop over space-filling curve + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + SrcData v; + + // apply element-wise operation + element_op_(v, src_vector_container.template AsType()[i]); + + // apply type convert + dst_vector_container.template AsType()(i) = type_convert(v); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = SrcResetCoordinateAfterRun + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp index 68bc2726f4be6fa1b08fe9213e0db9440526f975..cf2c7a2aee3d3f612fd12c0868477bfdfe49a2ea 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp index 0f5fb88b04540dbf355fb0f02998303b7680c193..b5847e51b42ea5b567133f0307b319fa3d8fadbb 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp index 2eb1b0ee90a446ee3b14354aff018275cf7809f4..db7dee21992d0e4edaa2e7a443b5c0ed68589107 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..299a4b9e7ded4b737227f8f111185cfb321a9df8 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp @@ -0,0 +1,386 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + index_t SrcScalarPerVector, + index_t DstScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags> // Sequence +struct ThreadwiseTensorSliceTransfer_v7r2 +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SrcSpaceFillingCurve = SpaceFillingCurve>; + + static constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using DstSpaceFillingCurve = SpaceFillingCurve>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! cannot evenly divide"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + }); + } + + template + __device__ static auto generate_vectors() + { + auto data_types = DataTypes{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + template = false> + __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs) + { + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(); + auto dst_vectors = generate_vectors(); + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), + is_src_valid); + }); + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + // apply pointwise function + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return dst_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + dst_vectors_tuple_(iAccess) = dst_vectors; + + // move coordinate + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + } + + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs) + { + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + auto dst_vectors = dst_vectors_tuple_[iAccess]; + + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + dst_bufs(i).template Update( + dst_coords_[i].GetOffset(), + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + if constexpr(num_access == 0) + { + return typename SrcSpaceFillingCurve::Index{}; + } + else + { + return SrcSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + if constexpr(num_access == 0) + { + return typename DstSpaceFillingCurve::Index{}; + } + else + { + return DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + using SrcVectorsType = decltype(generate_vectors()); + using DstVectorsType = decltype(generate_vectors()); + + static constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess(); + + StaticallyIndexedArray dst_vectors_tuple_; + + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp index 12ba2c5381311ab45ffad4be793651d176f5fc59..eb6715e8ebb382e438791f60193bfc23f06d1773 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a18443164822cc9ed5264eb142aa773866af0597 --- /dev/null +++ b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp @@ -0,0 +1,538 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_gemm_dpp.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +enum struct DppInstr +{ + dpp8_f16_1x32x2 = 0, + dpp8_f16_2x16x2, + dpp8_f16_2x32x2, + dpp8_f16_4x16x2, + dpp8_f16_4x32x2, + dpp8_f16_8x16x2, + dpp8_f16_8x32x2, + dpp8_f16_16x16x2, + dpp8_f16_32x8x2 +}; + +/** + * Structure representing DPP GEMM executed by a single wavefront. + * + * Each structure instantiation must contain the following fields: + * - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the + * number of threads in a wavefront; + * - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier, + * it's 8 in case of DPP8; + * - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM + * operation; + * - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM + * operation; + * - m_per_lanegroup - size along M dimension that is processed by a single lanegroup; + * - n_per_lanegroup - size along N dimension that is processed by a single lanegroup; + * - m_per_thread - size along M dimension of the tile calculated by a single thread; + * - n_per_thread - size along N dimension of the tile calculated by a single thread; + * - k_per_dpp - size along K dimension that is reduced in a single DPP GEMM operation; + * - share_a - indicates whether we share matrix A or matrix B between lanes using DPP modifiers. + * + * Not all the combinarions are supported now, for current restrictions see the static asserts + * in the DppSelector's contructor. + */ +template +struct dpp_type; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 32; + static constexpr index_t n_per_wave = 8; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 8; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 8; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 4; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 4; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 16; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 4; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 4; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 4; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 4; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 2; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 2; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 1; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 1; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 1; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 2; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 2; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 2; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 2; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 1; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 1; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template +struct DppSelector +{ + template + static constexpr auto GetDpp(); + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_8x32x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_8x16x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_16x16x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_32x8x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_1x32x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_2x32x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_2x16x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_4x16x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_4x32x2; + } + + static constexpr auto selected_dpp = dpp_type()>{}; + + __host__ __device__ constexpr DppSelector() + { + static_assert(selected_dpp.m_per_wave % selected_dpp.m_per_lanegroup == 0); + static_assert(selected_dpp.n_per_wave % selected_dpp.n_per_lanegroup == 0); + + static_assert(selected_dpp.k_per_dpp % 2 == 0); + + static_assert(selected_dpp.wave_size % selected_dpp.lanegroup_size == 0); + constexpr index_t num_dpp_per_wave = selected_dpp.wave_size / selected_dpp.lanegroup_size; + constexpr index_t num_wave_c_elems = selected_dpp.m_per_wave * selected_dpp.n_per_wave; + constexpr index_t num_dpp_c_elems = + selected_dpp.m_per_lanegroup * selected_dpp.n_per_lanegroup; + static_assert(num_wave_c_elems % num_dpp_c_elems == 0); + static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems); + + if constexpr(selected_dpp.share_a) + { + static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); + static_assert(selected_dpp.n_per_lanegroup % selected_dpp.n_per_thread == 0); + static_assert(selected_dpp.n_per_lanegroup / selected_dpp.n_per_thread == + selected_dpp.lanegroup_size); + } + else + { + static_assert(selected_dpp.m_per_lanegroup % selected_dpp.n_per_thread == 0); + static_assert(selected_dpp.m_per_lanegroup / selected_dpp.n_per_thread == + selected_dpp.lanegroup_size); + static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread); + } + + // Below checks come from the restrictions of the current implementation, could be removed + // in the future when the implementation is more generalized. + static_assert(selected_dpp.share_a); + static_assert(selected_dpp.n_per_thread == 1); + static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); + static_assert(selected_dpp.n_per_lanegroup == + selected_dpp.n_per_thread * selected_dpp.lanegroup_size); + } + + static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; } +}; + +template +struct DppGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __host__ __device__ constexpr DppGemm() + { + static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp."); + } + + __device__ static constexpr index_t GetRegSizePerDpp() + { + return MPerDpp * NPerDpp / dpp_instr.wave_size; + } + + template + __device__ void + Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "base BaseType must be double, float, half, bfloat16, and int8_t!"); + + static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) { + dpp_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); + }); + } + + __device__ static auto GetLaneIdInWave() + { + return get_thread_local_1d_id() % dpp_instr.wave_size; + } + + __device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; } + + __device__ static auto GetLaneIdInLaneGroup() + { + return get_thread_local_1d_id() % dpp_instr.lanegroup_size; + } + + __device__ static auto GetLaneGroupIdInWave() + { + return GetLaneIdInWave() / dpp_instr.lanegroup_size; + } + + __device__ static auto GetDppOpIdx() + { + const auto lanegroupId = GetLaneGroupIdInWave(); + + constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup, + dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex( + make_multi_index(lanegroupId)); + + const auto m_dpp_idx = dpp_idx[I0]; + const auto n_dpp_idx = dpp_idx[I1]; + + return make_tuple(m_dpp_idx, n_dpp_idx); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M() + { + const auto laneId = get_thread_local_1d_id(); + const auto wave_row = laneId / dpp_instr.n_per_wave; + auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup(); + return make_tuple(0, m_idx % dpp_instr.m_per_wave); + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N() + { + const auto laneId = get_thread_local_1d_id(); + return make_tuple(0, laneId % dpp_instr.n_per_wave); + } + + __device__ static CIndex GetBeginOfThreadBlk() + { + const auto dpp_op_idx = GetDppOpIdx(); + + const auto m_dpp_op_idx = dpp_op_idx[I0]; + const auto n_dpp_op_idx = dpp_op_idx[I1]; + + index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup(); + index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup; + + return CIndex{m_offset, n_offset}; + } + + static constexpr auto dpp = DppSelector{}; + + static constexpr auto dpp_instr = dpp.selected_dpp; + + static constexpr auto K0PerDpp = 1; + static constexpr auto K1PerDpp = dpp.GetK1PerDpp(); + + __host__ __device__ static constexpr auto GetCMNThreadBlkLengths() + { + return make_tuple(Number{}, Number{}); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 0672bf8e5b2e15421f90bae642800a1b7ffcc0d3..979f3567e987ed66fd45af5deeae7bb2ac3fc761 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -262,12 +262,12 @@ struct wmma_type + class FloatC, + bool neg_a = false, + bool neg_b = false, + bool clamp = false> __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { if constexpr(wave_size == 32) diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 319487bc053ccb7ae4598ca467836fd07196d7c8..835075b7f2c95c3a3630cdab874cbc90157e2dac 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -29,7 +29,15 @@ enum struct MfmaInstr mfma_i32_16x16x16i8, mfma_i32_32x32x16i8, mfma_i32_16x16x32i8, - mfma_f64_16x16x4f64 + mfma_f64_16x16x4f64, + mfma_f32_32x32x16f8f8, + mfma_f32_16x16x32f8f8, + mfma_f32_32x32x16bf8bf8, + mfma_f32_16x16x32bf8bf8, + mfma_f32_32x32x16f8bf8, + mfma_f32_16x16x32f8bf8, + mfma_f32_32x32x16bf8f8, + mfma_f32_16x16x32bf8f8 }; template @@ -454,10 +462,192 @@ struct mfma_type } }; -template +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16f8f8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32f8f8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16f8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32f8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf8f8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf8f8::Run(a, b, reg_c); + } +}; + +template struct MfmaSelector { - template + template static constexpr auto GetMfma(); template <> @@ -594,7 +784,56 @@ struct MfmaSelector } #endif - static constexpr auto selected_mfma = mfma_type()>{}; + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16f8f8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32f8f8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16bf8bf8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32bf8bf8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16f8bf8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32f8bf8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16bf8f8; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32bf8f8; + } + + static constexpr auto selected_mfma = + mfma_type()>{}; __host__ __device__ constexpr MfmaSelector() { @@ -641,7 +880,8 @@ template + typename additional_type = base_type, + bool TransposeC = false> struct XdlopsGemm { static constexpr auto I0 = Number<0>{}; @@ -792,10 +1032,14 @@ struct XdlopsGemm template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "base base_type must be double, float, half, bfloat16, and int8_t!"); + static_assert( + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || + (is_same::value && is_same::value) || + (is_same::value && is_same::value), + "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { if constexpr(!TransposeC) @@ -891,7 +1135,7 @@ struct XdlopsGemm return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; } - static constexpr auto mfma = MfmaSelector{}; + static constexpr auto mfma = MfmaSelector{}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp index 5fc11d9158ab9a1261c1285c2cd9f5208d7afbf3..ea27a40ce3c7a48564af112a97af342c157f746d 100644 --- a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp index 13d0a28cfe5b2db14a83a4b81e0bb379aabbe137..2be0b66812434eb5127f5a97534ca4dc36b68273 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,6 +13,150 @@ namespace ck { namespace tensor_operation { +namespace { +template < + index_t NDimSpatial, + typename ALayout, + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization> +constexpr auto make_out_grid_desc(const index_t N, + const index_t Do, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& out_g_n_k_wos_strides) +{ + const auto KStride = Number<1>{}; + + if constexpr(is_same_v) + { + const index_t NStride = out_g_n_k_wos_strides[1]; + const index_t HiStride = out_g_n_k_wos_strides[3]; + const index_t WiStride = out_g_n_k_wos_strides[4]; + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K), + make_tuple(WiStride, KStride)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N, Ho, Wo, K), + make_tuple(NStride, HiStride, WiStride, KStride)); + } + } + else if constexpr(is_same_v) + { + const index_t NStride = out_g_n_k_wos_strides[1]; + const index_t DoStride = out_g_n_k_wos_strides[3]; + const index_t HoStride = out_g_n_k_wos_strides[4]; + const index_t WoStride = out_g_n_k_wos_strides[5]; + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Do, Ho, Wo, K), + make_tuple(NStride, DoStride, HoStride, WoStride, KStride)); + } + } + else if constexpr(is_same_v) + { + // assume packed + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + } + else + { + return make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + } + } + else if constexpr(is_same_v) + { + // assume packed + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); + } + else + { + return make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); + } + } + else + { + throw std::runtime_error("wrong! unsupported layout: " + ALayout::name()); + } +} + +template +constexpr auto make_wei_grid_desc( + const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C) +{ + + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); + } + else + { + throw std::runtime_error("wrong! unsupported layout: " + BLayout::name()); + } +} + +template +constexpr auto make_in_grid_desc(const index_t N, + const index_t Di, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& in_g_n_c_wis_strides) +{ + + if constexpr(is_same_v || + is_same_v || + is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), + make_tuple(in_g_n_c_wis_strides[1], + in_g_n_c_wis_strides[3], + in_g_n_c_wis_strides[4], + in_g_n_c_wis_strides[2])); + } + else if constexpr(is_same_v || + is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C), + make_tuple(in_g_n_c_wis_strides[1], + in_g_n_c_wis_strides[3], + in_g_n_c_wis_strides[4], + in_g_n_c_wis_strides[5], + in_g_n_c_wis_strides[2])); + } + else + { + throw std::runtime_error("wrong! unsupported layout: " + CLayout::name()); + } +} + +} // namespace + template < index_t NDimSpatial, ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization, @@ -20,6 +164,7 @@ template < index_t BK1, index_t GemmMPerBlock, index_t GemmNPerBlock, + index_t GemmKPerBlock, bool DoPadGemmM, bool DoPadGemmN> struct TransformConvBwdDataToGemm_v1 @@ -27,13 +172,30 @@ struct TransformConvBwdDataToGemm_v1 static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = + NDimSpatial == 2 ? Number{} : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = + NDimSpatial == 2 ? Number{} : Number{}; + template , + typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) && + (is_same_v || + is_same_v || + is_same_v || + is_same_v), bool>::type = false> static auto MakeADescriptor_AK0_M_AK1( const std::array& out_g_n_k_wos_lengths, - const std::array& /* out_g_n_k_wos_strides */, + const std::array& out_g_n_k_wos_strides, const std::array& wei_g_k_c_xs_lengths, const std::array& /* wei_g_k_c_xs_strides */, const std::array& in_g_n_c_wis_lengths, @@ -44,44 +206,52 @@ struct TransformConvBwdDataToGemm_v1 const std::array& /* input_right_pads */, const std::array& tildes) { - index_t i_ytilde = tildes[0]; - index_t i_xtilde = tildes[1]; + index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum]; + index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum]; + index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum]; const index_t N = in_g_n_c_wis_lengths[1]; const index_t K = wei_g_k_c_xs_lengths[1]; - const index_t Hi = in_g_n_c_wis_lengths[3]; - const index_t Wi = in_g_n_c_wis_lengths[4]; - - const index_t Ho = out_g_n_k_wos_lengths[3]; - const index_t Wo = out_g_n_k_wos_lengths[4]; + const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1; + const index_t Hi = in_g_n_c_wis_lengths[HIdx]; + const index_t Wi = in_g_n_c_wis_lengths[WIdx]; - const index_t Y = wei_g_k_c_xs_lengths[3]; - const index_t X = wei_g_k_c_xs_lengths[4]; + const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1; + const index_t Ho = out_g_n_k_wos_lengths[HIdx]; + const index_t Wo = out_g_n_k_wos_lengths[WIdx]; - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; + const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1; + const index_t Y = wei_g_k_c_xs_lengths[YIdx]; + const index_t X = wei_g_k_c_xs_lengths[XIdx]; - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; + const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum]; + const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum]; + const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum]; - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; - const index_t AK0 = K / AK1; + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; - // assume packed - const auto out_n_ho_wo_k_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + // n_do_ho_wo_k for 3d or n_ho_wo_k for 2d + const auto out_grid_desc = + make_out_grid_desc( + N, Do, Ho, Wo, K, out_g_n_k_wos_strides); if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { + const index_t AK0 = math::integer_divide_ceil(K, AK1); + // A: output tensor const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), - make_tuple(make_pass_through_transform(N * Ho * Wo), + out_grid_desc, + make_tuple(make_pass_through_transform(N * Do * Ho * Wo), make_unmerge_transform(make_tuple(AK0, AK1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{})); @@ -96,103 +266,226 @@ struct TransformConvBwdDataToGemm_v1 } else { + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + const auto ZTilde = ConvStrideD / GcdStrideDilationD; const auto YTilde = ConvStrideH / GcdStrideDilationH; const auto XTilde = ConvStrideW / GcdStrideDilationW; + const auto ZDot = math::integer_divide_ceil(Z, ZTilde); const auto YDot = math::integer_divide_ceil(Y, YTilde); const auto XDot = math::integer_divide_ceil(X, XTilde); + const auto DTilde = + Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); const auto IHTildeSliceBegin = math::integer_divide_floor( math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); const auto IWTildeSliceBegin = math::integer_divide_floor( math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + const auto IDTildeSliceEnd = math::min( + DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); const auto IHTildeSliceEnd = math::min( HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); const auto IWTildeSliceEnd = math::min( WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; // GemmK is different for each GEMM + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); - // A: output tensor - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_n_ho_wo_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Ho, I0, I0), - make_pad_transform(Wo, I0, I0), - make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(YDot, HTilde), - make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, WTilde), - make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), - make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + if constexpr(NDimSpatial == 2) + { + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, make_tuple(make_pass_through_transform(N), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_unmerge_transform(make_tuple(AK0, AK1))), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5, 6>{})); - - const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)), - make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(AK1)), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto out_gemmak0_gemmm_gemmak1_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - out_gemmak0_gemmmraw_gemmak1_grid_desc, - make_tuple(AK0, GemmMPerBlock, AK1), - Sequence{}); - - return out_gemmak0_gemmm_gemmak1_grid_desc; + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk_gemmm_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmk_gemmmraw_grid_desc, + make_tuple(GemmKPerBlock, GemmMPerBlock), + Sequence{}); + + const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1; + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( + out_gemmk_gemmm_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform( + out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return out_gemmak0_gemmm_gemmak1_grid_desc; + } + else if constexpr(NDimSpatial == 3) + { + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Do, I0, I0), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = + transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform( + make_tuple(ZDot, DTilde), + make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), + make_embed_transform( + make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform( + make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{})); + + const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)), + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk_gemmm_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmk_gemmmraw_grid_desc, + make_tuple(GemmKPerBlock, GemmMPerBlock), + Sequence{}); + + const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1; + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( + out_gemmk_gemmm_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform( + out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return out_gemmak0_gemmm_gemmak1_grid_desc; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } } } template , + typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) && + (is_same_v || + is_same_v), bool>::type = false> static auto MakeBDescriptor_BK0_N_BK1( const std::array& out_g_n_k_wos_lengths, @@ -207,35 +500,40 @@ struct TransformConvBwdDataToGemm_v1 const std::array& /* input_right_pads */, const std::array& tildes) { - index_t i_ytilde = tildes[0]; - index_t i_xtilde = tildes[1]; + index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum]; + index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum]; + index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum]; const index_t N = in_g_n_c_wis_lengths[1]; const index_t K = wei_g_k_c_xs_lengths[1]; const index_t C = wei_g_k_c_xs_lengths[2]; - const index_t Ho = out_g_n_k_wos_lengths[3]; - const index_t Wo = out_g_n_k_wos_lengths[4]; + const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1; + const index_t Ho = out_g_n_k_wos_lengths[HIdx]; + const index_t Wo = out_g_n_k_wos_lengths[WIdx]; - const index_t Y = wei_g_k_c_xs_lengths[3]; - const index_t X = wei_g_k_c_xs_lengths[4]; + const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1; + const index_t Y = wei_g_k_c_xs_lengths[YIdx]; + const index_t X = wei_g_k_c_xs_lengths[XIdx]; - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t BK0 = K / BK1; + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; // assume packed - const auto wei_k_y_x_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + // k_y_x_c for 2d or k_z_y_x_c for 3d + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C); if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { + const index_t BK0 = math::integer_divide_ceil(K, BK1); + // B: weight tensor const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), @@ -243,7 +541,7 @@ struct TransformConvBwdDataToGemm_v1 make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1)); + make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, C), make_tuple(I0, I1)); const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( @@ -255,75 +553,175 @@ struct TransformConvBwdDataToGemm_v1 } else { + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + const auto ZTilde = ConvStrideD / GcdStrideDilationD; const auto YTilde = ConvStrideH / GcdStrideDilationH; const auto XTilde = ConvStrideW / GcdStrideDilationW; + const auto ZDot = math::integer_divide_ceil(Z, ZTilde); const auto YDot = math::integer_divide_ceil(Y, YTilde); const auto XDot = math::integer_divide_ceil(X, XTilde); // GemmK is different for each GEMM + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); // B weight tensor - const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_k_y_x_c_grid_desc, - make_tuple(make_pass_through_transform(K), - make_embed_transform(make_tuple(YDot, YTilde), - make_tuple(ConvStrideH / GcdStrideDilationH, I1)), - make_embed_transform(make_tuple(XDot, XTilde), - make_tuple(ConvStrideW / GcdStrideDilationW, I1)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_slice_transform(YDot, I0, YDotSlice), - make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(i_ytilde), - make_freeze_transform(i_xtilde), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<3>{}, - Sequence<2>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0, 1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<>{}, - Sequence<>{}, - Sequence<4>{})); - - const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor( - wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)), - make_pass_through_transform(C), - make_pass_through_transform(BK1)), - make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = - ck::tensor_operation::device::PadTensorDescriptor( - wei_gemmbk0_gemmnraw_gemmbk1_grid_desc, + if constexpr(NDimSpatial == 2) + { + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, make_tuple( - wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0), GemmNPerBlock, BK1), - Sequence{}); - - return wei_gemmbk0_gemmn_gemmbk1_grid_desc; + make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<3>{})); + + const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)), + make_pass_through_transform(C)), + make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto wei_gemmk_gemmn_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmk_gemmnraw_grid_desc, + make_tuple(GemmKPerBlock, GemmNPerBlock), + Sequence{}); + + const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1; + + const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform( + wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return wei_gemmbk0_gemmn_gemmbk1_grid_desc; + } + else if constexpr(NDimSpatial == 3) + { + const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc = + transform_tensor_descriptor( + wei_grid_desc, + make_tuple( + make_pass_through_transform(K), + make_embed_transform(make_tuple(ZDot, ZTilde), + make_tuple(ConvStrideD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor( + wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ztilde), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( + wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)), + make_pass_through_transform(C)), + make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto wei_gemmk_gemmn_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmk_gemmnraw_grid_desc, + make_tuple(GemmKPerBlock, GemmNPerBlock), + Sequence{}); + + const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1; + + const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform( + wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return wei_gemmbk0_gemm_gemmbk1_grid_desc; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } } } template || + is_same_v || is_same_v || + is_same_v || is_same_v), bool>::type = false> static auto @@ -339,153 +737,309 @@ struct TransformConvBwdDataToGemm_v1 const std::array& input_right_pads, const std::array& tildes) { - index_t i_ytilde = tildes[0]; - index_t i_xtilde = tildes[1]; + index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum]; + index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum]; + index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum]; const index_t N = in_g_n_c_wis_lengths[1]; const index_t C = wei_g_k_c_xs_lengths[2]; - const index_t Hi = in_g_n_c_wis_lengths[3]; - const index_t Wi = in_g_n_c_wis_lengths[4]; + const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1; + const index_t Hi = in_g_n_c_wis_lengths[HIdx]; + const index_t Wi = in_g_n_c_wis_lengths[WIdx]; - const index_t Ho = out_g_n_k_wos_lengths[3]; - const index_t Wo = out_g_n_k_wos_lengths[4]; + const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1; + const index_t Ho = out_g_n_k_wos_lengths[HIdx]; + const index_t Wo = out_g_n_k_wos_lengths[WIdx]; - const index_t Y = wei_g_k_c_xs_lengths[3]; - const index_t X = wei_g_k_c_xs_lengths[4]; + const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1; + const index_t Y = wei_g_k_c_xs_lengths[YIdx]; + const index_t X = wei_g_k_c_xs_lengths[XIdx]; - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; + const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum]; + const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum]; + const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum]; - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; + const index_t InRightPadD = input_right_pads[DIdx - NonSpatialDimsNum]; + const index_t InRightPadH = input_right_pads[HIdx - NonSpatialDimsNum]; + const index_t InRightPadW = input_right_pads[WIdx - NonSpatialDimsNum]; - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; // assume strided - const auto in_n_hi_wi_c_grid_desc = - make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), - make_tuple(in_g_n_c_wis_strides[1], - in_g_n_c_wis_strides[3], - in_g_n_c_wis_strides[4], - in_g_n_c_wis_strides[2])); + // n_hi_wi_c for 2d n_di_hi_wi_c for 3d + const auto in_grid_desc = + make_in_grid_desc(N, Di, Hi, Wi, C, in_g_n_c_wis_strides); if constexpr(ConvBwdDataSpecialization == ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { // C: input tensor - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), - make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_freeze_transform(I0), - make_freeze_transform(I0), - make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), - make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( - in_gemmmraw_gemmnraw_grid_desc, - make_tuple(GemmMPerBlock, GemmNPerBlock), - Sequence{}); - - return in_gemmm_gemmn_grid_desc; + if constexpr(NDimSpatial == 2) + { + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + else if constexpr(NDimSpatial == 3) + { + + // C: input tensor + const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_x_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<0, 2, 4, 6>{}, + Sequence<7>{}), + make_tuple( + Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } } else { + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + const auto ZTilde = ConvStrideD / GcdStrideDilationD; const auto YTilde = ConvStrideH / GcdStrideDilationH; const auto XTilde = ConvStrideW / GcdStrideDilationW; + const auto DTilde = + Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); - // only work on HTilde and WTilde that contribute to non-padding area of input tensor + // only work on DTilde, HTilde and WTilde that contribute to + // non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); const auto IHTildeSliceBegin = math::integer_divide_floor( math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); const auto IWTildeSliceBegin = math::integer_divide_floor( math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + const auto IDTildeSliceEnd = math::min( + DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); const auto IHTildeSliceEnd = math::min( HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); const auto IWTildeSliceEnd = math::min( WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; // C: input tensor - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(YTilde, HTilde), - make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(XTilde, WTilde), - make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_freeze_transform(i_ytilde), - make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(i_xtilde), - make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<>{}, - Sequence<1>{}, - Sequence<>{}, - Sequence<2>{}, - Sequence<3>{})); - - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( - in_gemmmraw_gemmnraw_grid_desc, - make_tuple(GemmMPerBlock, GemmNPerBlock), - Sequence{}); - - return in_gemmm_gemmn_grid_desc; + if constexpr(NDimSpatial == 2) + { + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + else if(NDimSpatial == 3) + { + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(ZTilde, DTilde), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ztilde), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + return in_gemmm_gemmn_grid_desc; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } } } diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index 1b5e64b66cf292fbcf4979d95aa4a5b1fa3e3a8b..6f546f1d6d27d61239ff4c1c0d03b898be48f58c 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -1,6 +1,6 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -20,348 +20,13 @@ struct TransformConvFwdToGemm static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - template , - bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& /* a_g_n_c_wis_strides */, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Wi = a_g_n_c_wis_lengths[3]; - - const index_t Wo = c_g_n_k_wos_lengths[3]; - - const index_t ConvStrideW = conv_filter_strides[0]; - - if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor_packed(make_tuple(NWo, C)); - - return in_gemmm_gemmk_desc; - } - else if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Pad0) - { - const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); - - const auto in_n_wo_c_desc = transform_tensor_descriptor( - in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - else - { - const index_t X = b_g_k_c_xs_lengths[3]; - const index_t ConvDilationW = conv_filter_dilations[0]; - const index_t InLeftPadW = input_left_pads[0]; - const index_t InRightPadW = input_right_pads[0]; - - const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); - - const auto in_n_wip_c_desc = transform_tensor_descriptor( - in_n_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - const auto in_n_x_wo_c_desc = transform_tensor_descriptor( - in_n_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Wo)), - make_merge_transform(make_tuple(X, C))), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - } - - template , - bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& /* a_g_n_c_wis_strides */, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Hi = a_g_n_c_wis_lengths[3]; - const index_t Wi = a_g_n_c_wis_lengths[4]; - - const index_t Ho = c_g_n_k_wos_lengths[3]; - const index_t Wo = c_g_n_k_wos_lengths[4]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; - - if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C)); - - return in_gemmm_gemmk_desc; - } - else if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Pad0) - { - const auto in_n_hi_wi_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - - const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( - in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - else - { - const index_t Y = b_g_k_c_xs_lengths[3]; - const index_t X = b_g_k_c_xs_lengths[4]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - const auto in_n_hi_wi_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - - const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( - in_n_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( - in_n_hip_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmm_gemmk_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), - make_merge_transform(make_tuple(Y, X, C))), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - } - - template , - bool>::type = false> - static auto - MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& /* a_g_n_c_wis_strides */, - const std::array& b_g_k_c_xs_lengths, - const std::array& /* b_g_k_c_xs_strides */, - const std::array& c_g_n_k_wos_lengths, - const std::array& /* c_g_n_k_wos_strides */, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) - { - const index_t N = a_g_n_c_wis_lengths[1]; - const index_t C = a_g_n_c_wis_lengths[2]; - - const index_t Di = a_g_n_c_wis_lengths[3]; - const index_t Hi = a_g_n_c_wis_lengths[4]; - const index_t Wi = a_g_n_c_wis_lengths[5]; - - const index_t Do = c_g_n_k_wos_lengths[3]; - const index_t Ho = c_g_n_k_wos_lengths[4]; - const index_t Wo = c_g_n_k_wos_lengths[5]; - - const index_t ConvStrideD = conv_filter_strides[0]; - const index_t ConvStrideH = conv_filter_strides[1]; - const index_t ConvStrideW = conv_filter_strides[2]; - - if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) - { - const index_t NDoHoWo = - N * ck::accumulate_n( - c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>()); - - const auto in_gemmm_gemmk_desc = - make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C)); - - return in_gemmm_gemmk_desc; - } - else if constexpr(ConvForwardSpecialization == - device::ConvolutionForwardSpecialization::Filter1x1Pad0) - { - const auto in_n_di_hi_wi_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); - - const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_do_ho_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_pass_through_transform(C)), - make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - else - { - const index_t Z = b_g_k_c_xs_lengths[3]; - const index_t Y = b_g_k_c_xs_lengths[4]; - const index_t X = b_g_k_c_xs_lengths[5]; - - const index_t ConvDilationD = conv_filter_dilations[0]; - const index_t ConvDilationH = conv_filter_dilations[1]; - const index_t ConvDilationW = conv_filter_dilations[2]; - - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; - - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; - - const auto in_n_di_hi_wi_c_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); - - const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( - in_n_di_hi_wi_c_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Di, InLeftPadD, InRightPadD), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( - in_n_hip_wip_c_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), - make_tuple(Sequence<0>{}, - Sequence<1, 2>{}, - Sequence<3, 4>{}, - Sequence<5, 6>{}, - Sequence<7>{})); - - const auto in_gemmm_gemmk_desc = transform_tensor_descriptor( - in_n_z_do_y_ho_x_wo_c_desc, - make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), - make_merge_transform(make_tuple(Z, Y, X, C))), - make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return in_gemmm_gemmk_desc; - } - } - // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as // properties template || - is_same_v), + is_same_v || + is_same_v), bool>::type = false> static auto MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, @@ -473,7 +138,8 @@ struct TransformConvFwdToGemm template || - is_same_v), + is_same_v || + is_same_v), bool>::type = false> static auto MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, @@ -601,7 +267,8 @@ struct TransformConvFwdToGemm template || - is_same_v), + is_same_v || + is_same_v), bool>::type = false> static auto MakeADescriptor_M_K(const std::array& a_g_n_c_wis_lengths, diff --git a/include/ck/utility/amd_address_space.hpp b/include/ck/utility/amd_address_space.hpp index 9f1525914cdee16ee6cbed51297f0008d46f6085..d54f70e750e8032f706ae3f7f0dfec9f7cab4c8d 100644 --- a/include/ck/utility/amd_address_space.hpp +++ b/include/ck/utility/amd_address_space.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index bdfb4f27580e011f0bebabfeae9146374a5ecd67..027aa72f18115365b80ab3df6c8b3b6521b9ebc3 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" @@ -286,505 +286,270 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int soffset, // dst_wave_addr_offset int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); -template -__device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset) +// memory coherency bit for buffer store/load instruction +// check ISA manual for each GFX target +// e.g. for +// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf, +// page 67~68 +enum struct AmdBufferCoherenceEnum { - static_assert( - (is_same::value && (N == 1 || N == 2 || N == 4)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), - "wrong! not implemented"); - - if constexpr(is_same::value) - { - // use fp32 load to mimic fp64 load - if constexpr(N == 1) - { - const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - - return bit_cast(tmp); - } - else if constexpr(N == 2) - { - const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - - return bit_cast(tmp); - } - else if constexpr(N == 4) - { - const float4_t f32_0 = llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - - const float4_t f32_1 = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(float), - 0); - vector_type tmp; + DefaultCoherence = 0, // default value + GLC = 1, + SLC = 2, + GLC_SLC = 3, +}; - tmp.AsType()(Number<0>{}) = bit_cast(f32_0); - tmp.AsType()(Number<1>{}) = bit_cast(f32_1); +template +__device__ typename vector_type::type +amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + "wrong! not implemented"); - return tmp.AsType()(Number<0>{}); - } + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 2) { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_fp32( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - else if constexpr(N == 2) - { - return llvm_amdgcn_raw_buffer_load_fp32x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - else if constexpr(N == 4) - { - return llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - else if constexpr(N == 8) - { - vector_type tmp; - tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); - tmp.AsType()(Number<1>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(float), - 0); - - return tmp.AsType()(Number<0>{}); - } + return bit_cast(tmp); } - else if constexpr(is_same::value) + else if constexpr(N == 4) { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_fp16( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - else if constexpr(N == 2) - { - return llvm_amdgcn_raw_buffer_load_fp16x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - else if constexpr(N == 4) - { - return llvm_amdgcn_raw_buffer_load_fp16x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - else if constexpr(N == 8) - { - // use fp32 load to mimic fp16 load - float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); - return bit_cast(tmp); - } + return bit_cast(tmp); } - else if constexpr(is_same::value) + else if constexpr(N == 8) { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_i16( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - else if constexpr(N == 2) - { - return llvm_amdgcn_raw_buffer_load_i16x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - else if constexpr(N == 4) - { - return llvm_amdgcn_raw_buffer_load_i16x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - else if constexpr(N == 8) - { - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); - return bit_cast(tmp); - } + return bit_cast(tmp); } - else if constexpr(is_same::value) + else if constexpr(N == 16) { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_i32( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - else if constexpr(N == 2) - { - return llvm_amdgcn_raw_buffer_load_i32x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - else if constexpr(N == 4) - { - return llvm_amdgcn_raw_buffer_load_i32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - else if constexpr(N == 8) - { - vector_type tmp; - - tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - - tmp.AsType()(Number<1>{}) = - llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int32_t), - 0); - return tmp.AsType()(Number<0>{}); - } + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + return bit_cast(tmp); } - else if constexpr(is_same::value) + else if constexpr(N == 32) { - if constexpr(N == 1) - { - return llvm_amdgcn_raw_buffer_load_i8( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - } - else if constexpr(N == 2) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - return llvm_amdgcn_raw_buffer_load_i8x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); -#else - int16_t tmp = llvm_amdgcn_raw_buffer_load_i16( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - - return bit_cast(tmp); -#endif - } - else if constexpr(N == 4) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - return llvm_amdgcn_raw_buffer_load_i8x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); -#else - int32_t tmp = llvm_amdgcn_raw_buffer_load_i32( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - - return bit_cast(tmp); -#endif - } - else if constexpr(N == 8) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - vector_type tmp; - - tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - - tmp.AsType()(Number<1>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int8_t), - 0); - - return tmp.AsType()(Number<0>{}); -#else - int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - - return bit_cast(tmp); -#endif - } - else if constexpr(N == 16) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - vector_type tmp; - - tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - - tmp.AsType()(Number<1>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int8_t), - 0); - - tmp.AsType()(Number<2>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 8 * sizeof(int8_t), - 0); - - tmp.AsType()(Number<3>{}) = - llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 12 * sizeof(int8_t), - 0); - - return tmp.AsType()(Number<0>{}); -#else - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - - return bit_cast(tmp); -#endif - } + int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + int32x4_t tmp1 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + static_cast(coherence)); + vector_type tmp; + + tmp.AsType()(Number<0>{}) = tmp0; + tmp.AsType()(Number<1>{}) = tmp1; + + return bit_cast(tmp); + } + else if constexpr(N == 64) + { + int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + int32x4_t tmp1 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + static_cast(coherence)); + int32x4_t tmp2 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(int32_t), + static_cast(coherence)); + int32x4_t tmp3 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(int32_t), + static_cast(coherence)); + + vector_type tmp; + + tmp.AsType()(Number<0>{}) = tmp0; + tmp.AsType()(Number<1>{}) = tmp1; + tmp.AsType()(Number<2>{}) = tmp2; + tmp.AsType()(Number<3>{}) = tmp3; + + return bit_cast(tmp); } } -template -__device__ void amd_buffer_store_impl(const typename vector_type::type src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) +template +__device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) { static_assert( - (is_same::value && (N == 1 || N == 2)) || - (is_same::value && (N == 1 || N == 2 || N == 4)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); - if constexpr(is_same::value) + using r_t = typename vector_type::type; + auto raw_data = amd_buffer_load_impl_raw( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset); + return bit_cast(raw_data); +} + +template +__device__ void +amd_buffer_store_impl_raw(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + "wrong! not implemented"); + + if constexpr(N == 1) { - // use fp32 store to mimic fp64 store - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } + llvm_amdgcn_raw_buffer_store_i8(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 2) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp32(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } + + llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 4) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 8) - { -#if 0 - vector_type tmp{src_thread_data}; - - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(half_t), - 0); -#else - llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); -#endif - } + llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 8) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_i16(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 8) - { - vector_type tmp{src_thread_data}; - - llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[Number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType()[Number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(bhalf_t), - 0); - } + llvm_amdgcn_raw_buffer_store_i32x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 16) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_i32(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } + llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); } - else if constexpr(is_same::value) + else if constexpr(N == 32) { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_i8(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); -#else - llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); -#endif - } - else if constexpr(N == 4) - { -#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE - llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); -#else - llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); -#endif - } - else if constexpr(N == 8) - { - llvm_amdgcn_raw_buffer_store_i32x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 16) - { - llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } + vector_type tmp{bit_cast(src_thread_data)}; + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); + } + else if constexpr(N == 64) + { + vector_type tmp{bit_cast(src_thread_data)}; + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 8, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 12, + static_cast(coherence)); } } +template +__device__ void amd_buffer_store_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! not implemented"); + + using r_t = typename vector_type::type; + + amd_buffer_store_impl_raw(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset); +} + template __device__ void amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_data, int32x4_t dst_wave_buffer_resource, @@ -1012,7 +777,9 @@ __device__ void amd_buffer_atomic_max_impl(const typename vector_type::typ // 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. -template +template __device__ typename vector_type_maker::type::type amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, index_t src_thread_element_offset, @@ -1031,13 +798,13 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; - - return amd_buffer_load_impl( + return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); + #else - vector_t tmp = amd_buffer_load_impl( - src_wave_buffer_resource, src_thread_addr_offset, 0); + vector_t tmp = amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0); return src_thread_element_valid ? tmp : vector_t(0); #endif } @@ -1046,7 +813,9 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, // 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. -template +template __device__ typename vector_type_maker::type::type amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, index_t src_thread_element_offset, @@ -1064,7 +833,7 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, constexpr index_t vector_size = scalar_type::vector_size; - vector_t tmp = amd_buffer_load_impl( + vector_t tmp = amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0); return src_thread_element_valid ? tmp : vector_t(customized_value); @@ -1074,7 +843,9 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, // 1) p_dst_wave must point to global memory // 2) p_dst_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. -template +template __device__ void amd_buffer_store(const typename vector_type_maker::type::type src_thread_data, T* p_dst_wave, const index_t dst_thread_element_offset, @@ -1092,13 +863,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - - amd_buffer_store_impl( + amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); #else if(dst_thread_element_valid) { - amd_buffer_store_impl( + amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); } #endif diff --git a/include/ck/utility/amd_gemm_dpp.hpp b/include/ck/utility/amd_gemm_dpp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a28292dade3dcad8ddcd5bd3c0cd119aeb4b4253 --- /dev/null +++ b/include/ck/utility/amd_gemm_dpp.hpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/inner_product_dpp8.hpp" + +namespace ck { + +namespace dpp8 { + +template +struct dpp_datatypes; + +template <> +struct dpp_datatypes +{ + // Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a + // single instruction. + using a_dtype = half_t; + using b_dtype = half_t; + using c_dtype = float; + static constexpr index_t k_per_instr = 2; +}; + +template +struct DppLanegroupGemm +{ + using datatypes_conf = dpp_datatypes; + using ADataType = typename datatypes_conf::a_dtype; + using BDataType = typename datatypes_conf::b_dtype; + using CDataType = typename datatypes_conf::c_dtype; + + __device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec) + { + constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread; + + const vector_type a_vector{a_vec}; + const vector_type b_vector{b_vec}; + + static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) { + float c = c_vec.template AsType()(c_idx); + // Next `c_idx` implies that we need to pull data from the next lane. + constexpr index_t source_lane = c_idx; + static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) { + const auto a_k_vec = a_vector.template AsType()[k_chunk]; + const auto b_k_vec = b_vector.template AsType()[k_chunk]; + ck::dpp8:: + inner_product_dpp( + a_k_vec, b_k_vec, c); + }); + c_vec.template AsType()(c_idx) = c; + }); + } +}; + +} // namespace dpp8 + +} // namespace ck diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 1f7df70bcd565dcba4d0aee3a49aca32f835bb45..43baa817d3672c0b2eed6059a88351aeb19a5209 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_AMD_INLINE_ASM_HPP #define CK_AMD_INLINE_ASM_HPP diff --git a/include/ck/utility/amd_llvm_intrinsic.hpp b/include/ck/utility/amd_llvm_intrinsic.hpp deleted file mode 100644 index 01e77d7be897085b8c5816fe9e20fe924b859ba7..0000000000000000000000000000000000000000 --- a/include/ck/utility/amd_llvm_intrinsic.hpp +++ /dev/null @@ -1,14 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#ifndef CK_AMD_LLVM_INTRINSIC_HPP -#define CK_AMD_LLVM_INTRINSIC_HPP - -#include "data_type.hpp" - -namespace ck { - -__device__ int32_t llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane"); - -} // namespace ck -#endif diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp new file mode 100644 index 0000000000000000000000000000000000000000..741b2975af6c5bf99346b1460018eac6fa33b21b --- /dev/null +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/functional2.hpp" +#include "ck/utility/math.hpp" + +#include +#include +#include +#include + +namespace ck { +namespace detail { + +template +struct get_carrier; + +template <> +struct get_carrier<1> +{ + using type = uint8_t; +}; + +template <> +struct get_carrier<2> +{ + using type = uint16_t; +}; + +template <> +struct get_carrier<3> +{ + using type = class carrier + { + using value_type = uint32_t; + + std::array bytes; + static_assert(sizeof(bytes) <= sizeof(value_type)); + + // replacement of host std::copy_n() + template + __device__ static OutputIterator copy_n(InputIterator from, Size size, OutputIterator to) + { + if(0 < size) + { + *to = *from; + ++to; + for(Size count = 1; count < size; ++count) + { + *to = *++from; + ++to; + } + } + + return to; + } + + // method to trigger template substitution failure + __device__ carrier(const carrier& other) noexcept + { + copy_n(other.bytes.begin(), bytes.size(), bytes.begin()); + } + + public: + __device__ carrier& operator=(value_type value) noexcept + { + copy_n(reinterpret_cast(&value), bytes.size(), bytes.begin()); + + return *this; + } + + __device__ operator value_type() const noexcept + { + std::byte result[sizeof(value_type)]; + + copy_n(bytes.begin(), bytes.size(), result); + + return *reinterpret_cast(result); + } + }; +}; +static_assert(sizeof(get_carrier<3>::type) == 3); + +template <> +struct get_carrier<4> +{ + using type = uint32_t; +}; + +template +using get_carrier_t = typename get_carrier::type; + +} // namespace detail + +__device__ inline int32_t amd_wave_read_first_lane(int32_t value) +{ + return __builtin_amdgcn_readfirstlane(value); +} + +template < + typename Object, + typename = std::enable_if_t && std::is_trivially_copyable_v>> +__device__ auto amd_wave_read_first_lane(const Object& obj) +{ + using Size = unsigned; + constexpr Size SgprSize = 4; + constexpr Size ObjectSize = sizeof(Object); + + auto* const from_obj = reinterpret_cast(&obj); + alignas(Object) std::byte to_obj[ObjectSize]; + + constexpr Size RemainedSize = ObjectSize % SgprSize; + constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize; + for(Size offset = 0; offset < CompleteSgprCopyBoundary; offset += SgprSize) + { + using Sgpr = detail::get_carrier_t; + + *reinterpret_cast(to_obj + offset) = + amd_wave_read_first_lane(*reinterpret_cast(from_obj + offset)); + } + + if constexpr(0 < RemainedSize) + { + using Carrier = detail::get_carrier_t; + + *reinterpret_cast(to_obj + CompleteSgprCopyBoundary) = amd_wave_read_first_lane( + *reinterpret_cast(from_obj + CompleteSgprCopyBoundary)); + } + + /// NOTE: Implicitly start object lifetime. It's better to use std::start_lifetime_at() in this + /// scenario + return *reinterpret_cast(to_obj); +} + +} // namespace ck diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index bf091425485d75a244399f4781e1d44c4b38dd2c..dd7f0b770a11470037b1d33164d830cb9819e8d1 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_AMD_WMMA_HPP #define CK_AMD_WMMA_HPP diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index a742496fc1f2fcf2ef529bbf079c481bc8d99f03..afc066405e87c5c65811156ac51b2bb4bd452dda 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1,10 +1,7 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#ifndef CK_AMD_XDLOPS_HPP -#define CK_AMD_XDLOPS_HPP - -#include "data_type.hpp" +#pragma once namespace ck { @@ -344,7 +341,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> template __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c) { -#if defined(__gfx90a__) || defined(__gfx940__) +#if defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else @@ -354,5 +351,257 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> #endif } }; -} // namespace ck + +template +struct intrin_mfma_f32_32x32x16f8f8; + +template <> +struct intrin_mfma_f32_32x32x16f8f8<32, 32> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32f8f8; + +template <> +struct intrin_mfma_f32_16x16x32f8f8<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_32x32x16bf8bf8; + +template <> +struct intrin_mfma_f32_32x32x16bf8bf8<32, 32> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); #endif + } +}; + +template +struct intrin_mfma_f32_16x16x32bf8bf8; + +template <> +struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_32x32x16f8bf8; + +template <> +struct intrin_mfma_f32_32x32x16f8bf8<32, 32> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32f8bf8; + +template <> +struct intrin_mfma_f32_16x16x32f8bf8<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_32x32x16bf8f8; + +template <> +struct intrin_mfma_f32_32x32x16bf8f8<32, 32> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32bf8f8; + +template <> +struct intrin_mfma_f32_16x16x32bf8f8<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +} // namespace ck diff --git a/include/ck/utility/array.hpp b/include/ck/utility/array.hpp index 370a457fe9d97a621e6f91e78f14733398ebf756..f63ce5e5a07a796888cb60ae8da0c855df75e7ff 100644 --- a/include/ck/utility/array.hpp +++ b/include/ck/utility/array.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_ARRAY_HPP #define CK_ARRAY_HPP diff --git a/include/ck/utility/array_multi_index.hpp b/include/ck/utility/array_multi_index.hpp index 9b8d5b95e9f6241b574104313b5b8fa2984ab18d..c0c1ea65fc77dd625aaaf6456a6628e09df27b51 100644 --- a/include/ck/utility/array_multi_index.hpp +++ b/include/ck/utility/array_multi_index.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_ARRAY_MULTI_INDEX_HPP #define CK_ARRAY_MULTI_INDEX_HPP diff --git a/include/ck/utility/c_style_pointer_cast.hpp b/include/ck/utility/c_style_pointer_cast.hpp index 6e8b0081587afcb2db159a05ecd5fb940def68ff..610e393a77216500448c3682b5db9b9860a077da 100644 --- a/include/ck/utility/c_style_pointer_cast.hpp +++ b/include/ck/utility/c_style_pointer_cast.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_C_STYLE_POINTER_CAST_HPP #define CK_C_STYLE_POINTER_CAST_HPP diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index 1378bbe448e2e1862fed04596fba7f4b106d03b7..f95660a8a47d2c93a65a497a22a74c4c0b6eaea4 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -24,6 +24,7 @@ #include "ck/utility/tuple.hpp" #include "ck/utility/tuple_helper.hpp" #include "ck/utility/type.hpp" +#include "ck/utility/type_convert.hpp" #include "ck/utility/magic_division.hpp" #include "ck/utility/c_style_pointer_cast.hpp" #include "ck/utility/is_known_at_compile_time.hpp" @@ -33,6 +34,7 @@ #include "ck/utility/debug.hpp" #include "ck/utility/amd_buffer_addressing.hpp" +#include "ck/utility/amd_wave_read_first_lane.hpp" #include "ck/utility/generic_memory_space_atomic.hpp" #include "ck/utility/get_id.hpp" #include "ck/utility/thread_group.hpp" diff --git a/include/ck/utility/container_element_picker.hpp b/include/ck/utility/container_element_picker.hpp index abc5185e04a5079edb95fe15bfcd6788256d384a..838147e420cc7c95b63273b32ae3b2be2429c017 100644 --- a/include/ck/utility/container_element_picker.hpp +++ b/include/ck/utility/container_element_picker.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_CONTAINER_ELEMENT_PICKER_HPP #define CK_CONTAINER_ELEMENT_PICKER_HPP diff --git a/include/ck/utility/container_helper.hpp b/include/ck/utility/container_helper.hpp index c8b02bc5acaf8542bf028fff39c42061db3b3296..9c7b954565d386a8fdecd21052b102e750ab7102 100644 --- a/include/ck/utility/container_helper.hpp +++ b/include/ck/utility/container_helper.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_CONTAINER_HELPER_HPP #define CK_CONTAINER_HELPER_HPP diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 101061191ef33fd922e1777119718d20627efc11..d367ad8df54fe27dc42d99ab0189a2aa261e3bde 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -9,9 +9,9 @@ namespace ck { using bhalf_t = ushort; using half_t = _Float16; -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -using int4_t = _BitInt(4); -#endif +using int4_t = _BitInt(4); +using f8_t = _BitInt(8); +using bf8_t = unsigned _BitInt(8); // vector_type template @@ -142,7 +142,20 @@ struct scalar_type }; #endif -// +template <> +struct scalar_type +{ + using type = f8_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bf8_t; + static constexpr index_t vector_size = 1; +}; + template struct vector_type { @@ -944,125 +957,21 @@ using int8x16_t = typename vector_type::type; using int8x32_t = typename vector_type::type; using int8x64_t = typename vector_type::type; -// Convert X to Y -template -__host__ __device__ constexpr Y type_convert(X x) -{ - static_assert(!std::is_reference_v && !std::is_reference_v); - - return static_cast(x); -} - -// convert bfp16 to fp32 -template <> -inline __host__ __device__ constexpr float type_convert(bhalf_t x) -{ - union - { - uint32_t int32; - float fp32; - } u = {uint32_t(x) << 16}; - - return u.fp32; -} - -// convert fp32 to bfp16 -template <> -inline __host__ __device__ constexpr bhalf_t type_convert(float x) -{ - union - { - float fp32; - uint32_t int32; - } u = {x}; - - // When the exponent bits are not all 1s, then the value is zero, normal, - // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus - // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). - // This causes the bfloat16's mantissa to be incremented by 1 if the 16 - // least significant bits of the float mantissa are greater than 0x8000, - // or if they are equal to 0x8000 and the least significant bit of the - // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when - // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already - // has the value 0x7f, then incrementing it causes it to become 0x00 and - // the exponent is incremented by one, which is the next higher FP value - // to the unrounded bfloat16 value. When the bfloat16 value is subnormal - // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up - // to a normal value with an exponent of 0x01 and a mantissa of 0x00. - // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, - // incrementing it causes it to become an exponent of 0xFF and a mantissa - // of 0x00, which is Inf, the next higher value to the unrounded value. - bool flag0 = ~u.int32 & 0x7f800000; - - // When all of the exponent bits are 1, the value is Inf or NaN. - // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero - // mantissa bit. Quiet NaN is indicated by the most significant mantissa - // bit being 1. Signaling NaN is indicated by the most significant - // mantissa bit being 0 but some other bit(s) being 1. If any of the - // lower 16 bits of the mantissa are 1, we set the least significant bit - // of the bfloat16 mantissa, in order to preserve signaling NaN in case - // the bfloat16's mantissa bits are all 0. - bool flag1 = !flag0 && (u.int32 & 0xffff); - - u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even - u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN - - return uint16_t(u.int32 >> 16); -} - -// convert bfp16 to fp16 via fp32 -template <> -inline __host__ __device__ constexpr half_t type_convert(bhalf_t x) -{ - float x_fp32 = type_convert(x); - - return static_cast(x_fp32); -} - -// convert fp16 to bfp16 via fp32 -template <> -inline __host__ __device__ constexpr bhalf_t type_convert(half_t x) -{ - float x_fp32 = static_cast(x); - - return type_convert(x_fp32); -} - -// convert bfp16 to int32 via fp32 -template <> -inline __host__ __device__ constexpr int32_t type_convert(bhalf_t x) -{ - float x_fp32 = type_convert(x); - - return static_cast(x_fp32); -} - -// convert int32 to bfp16 via fp32 -template <> -inline __host__ __device__ constexpr bhalf_t type_convert(int32_t x) -{ - float x_fp32 = static_cast(x); - - return type_convert(x_fp32); -} - -// convert bfp16 to int8 via fp32 -template <> -inline __host__ __device__ constexpr int8_t type_convert(bhalf_t x) -{ - float x_fp32 = type_convert(x); - - return static_cast(x_fp32); -} - -// convert int8 to bfp16 via fp32 -template <> -inline __host__ __device__ constexpr bhalf_t type_convert(int8_t x) -{ - float x_fp32 = static_cast(x); - - return type_convert(x_fp32); -} +// f8 +using f8x2_t = typename vector_type::type; +using f8x4_t = typename vector_type::type; +using f8x8_t = typename vector_type::type; +using f8x16_t = typename vector_type::type; +using f8x32_t = typename vector_type::type; +using f8x64_t = typename vector_type::type; + +// bf8 +using bf8x2_t = typename vector_type::type; +using bf8x4_t = typename vector_type::type; +using bf8x8_t = typename vector_type::type; +using bf8x16_t = typename vector_type::type; +using bf8x32_t = typename vector_type::type; +using bf8x64_t = typename vector_type::type; template struct NumericLimits @@ -1110,4 +1019,106 @@ struct NumericLimits }; #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 8 + static constexpr uint8_t binary_min = 0x08; // 0b00001000 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 7 + // static constexpr uint8_t binary_min = 0x08; // 0b00001000 + // static constexpr uint8_t binary_max = 0x77; // 0b01110111 + // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 + + __host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); } + + __host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); } + + __host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); } + + __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); } +}; + +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 16 + static constexpr uint8_t binary_min = 0x04; // 0b00000100 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 15 + // static constexpr uint8_t binary_min = 0x04; // 0b00000100 + // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 + // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= + + __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); } + + __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); } + + __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); } + + __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); } +}; + +template +struct NumericUtils +{ +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 23; + static constexpr int bias = 127; + static constexpr uint32_t nan_mask = 0x7F800000; + static constexpr uint32_t head_mask = 0xFF800000; + static constexpr uint32_t mant_mask = 0x7FFFFF; + static constexpr uint32_t exp_mask = 0xFF; + static constexpr uint32_t Inf = 0x7F800000; + static constexpr uint32_t NegInf = 0xFF800000; + static constexpr uint32_t NaN = 0x7F800001; + static constexpr uint32_t Neg0 = 0x80000000; + using bitwise_type = uint32_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 10; + static constexpr int bias = 15; + static constexpr uint16_t nan_mask = 0x7C00; + static constexpr uint16_t head_mask = 0xFC00; + static constexpr uint16_t mant_mask = 0x3FF; + static constexpr uint16_t exp_mask = 0x1F; + static constexpr uint32_t Inf = 0x7C00; + static constexpr uint32_t NegInf = 0xFC00; + static constexpr uint32_t NaN = 0x7C01; + static constexpr uint32_t Neg0 = 0x8000; + using bitwise_type = uint16_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 4; + static constexpr int mant = 3; + static constexpr int bias = 8; // negative zero nan mode + // static constexpr int bias = 7; // ieee mode +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 2; + static constexpr int bias = 16; // negative zero nan mode + // static constexpr int bias = 15; // ieee mode +}; } // namespace ck diff --git a/include/ck/utility/debug.hpp b/include/ck/utility/debug.hpp index 593bbb711672f7dc6e3db22d877bf249b5dff56d..80346f0d9f6f9e5a6a28dcab4e8a7666639394ca 100644 --- a/include/ck/utility/debug.hpp +++ b/include/ck/utility/debug.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index c6f0d299ef3c35c63ad97424888545edd2874914..85756cd0d1daeb2757c6dab4f5f0c9abe3504c53 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -19,7 +19,8 @@ namespace ck { template + bool InvalidElementUseNumericalZeroValue, + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence> struct DynamicBuffer { using type = T; @@ -77,13 +78,16 @@ struct DynamicBuffer if constexpr(InvalidElementUseNumericalZeroValue) { - return amd_buffer_load_invalid_element_return_zero, t_per_x>( + return amd_buffer_load_invalid_element_return_zero, + t_per_x, + coherence>( p_data_, i, is_valid_element, element_space_size_); } else { return amd_buffer_load_invalid_element_return_customized_value, - t_per_x>( + t_per_x, + coherence>( p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); } } @@ -136,10 +140,36 @@ struct DynamicBuffer } else if constexpr(Op == InMemoryDataOperationEnum::Add) { - auto tmp = this->template Get(i, is_valid_element); - this->template Set(i, is_valid_element, x + tmp); - // tmp += x; - // this->template Set(i, is_valid_element, tmp); + auto tmp = this->template Get(i, is_valid_element); + using scalar_t = typename scalar_type>::type; + // handle bfloat addition + if constexpr(is_same_v) + { + if constexpr(is_scalar_type::value) + { + // Scalar type + auto result = + type_convert(type_convert(x) + type_convert(tmp)); + this->template Set(i, is_valid_element, result); + } + else + { + // Vector type + constexpr auto vector_size = scalar_type>::vector_size; + const vector_type a_vector{tmp}; + const vector_type b_vector{x}; + static_for<0, vector_size, 1>{}([&](auto idx) { + auto result = type_convert( + type_convert(a_vector.template AsType()[idx]) + + type_convert(b_vector.template AsType()[idx])); + this->template Set(i + idx, is_valid_element, result); + }); + } + } + else + { + this->template Set(i, is_valid_element, x + tmp); + } } } @@ -173,7 +203,7 @@ struct DynamicBuffer { constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - amd_buffer_store, t_per_x>( + amd_buffer_store, t_per_x, coherence>( x, p_data_, i, is_valid_element, element_space_size_); } else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && @@ -376,14 +406,19 @@ struct DynamicBuffer __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } }; -template +template __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) { - return DynamicBuffer{p, element_space_size}; + return DynamicBuffer{ + p, element_space_size}; } template < AddressSpaceEnum BufferAddressSpace, + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, typename T, typename ElementSpaceSize, typename X, @@ -391,7 +426,7 @@ template < __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value) { - return DynamicBuffer{ + return DynamicBuffer{ p, element_space_size, invalid_element_value}; } diff --git a/include/ck/utility/enable_if.hpp b/include/ck/utility/enable_if.hpp index 297434b0dddd5f1a680176c3bd099bedfda8dff4..c0a3c99f1fdafea9f151fe9fc319c2f7aaa0ffda 100644 --- a/include/ck/utility/enable_if.hpp +++ b/include/ck/utility/enable_if.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1960667732646d04242e168485561b40e5417ab4 --- /dev/null +++ b/include/ck/utility/f8_utils.hpp @@ -0,0 +1,293 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" + +namespace ck { + +// fp8 rounding modes +// use standard for rounding to nearest, the faster one +// use stochastic for stochastic rounding, helps to avoid error accumulation +enum class f8_rounding_mode +{ + standard, + stochastic +}; + +__host__ inline int clz(uint32_t x) { return __builtin_clz(x); } +__device__ inline int clz(uint32_t x) { return __clz(x); } + +} // namespace ck + +namespace ck::utils { + +namespace { + +template +__host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) +{ + // fp8/bf8 exponent/mantissa layout + constexpr int out_exp = NumericUtils::exp; + constexpr int out_mant = NumericUtils::mant; + + // original type exponent/mantissa layout + constexpr int in_exp = NumericUtils::exp; + constexpr int in_mant = NumericUtils::mant; + + int exponent, bias; + uint32_t head, mantissa, sign; + // nan code is same for float and half + constexpr Y nan_code = 0x80; + constexpr uint32_t nan_mask = NumericUtils::nan_mask; + + // convert to bitwise + using T_bitwise = typename NumericUtils::bitwise_type; + T_bitwise x_bitwise = *(reinterpret_cast(&x)); + + // unpack the input, depends on datatype + head = x_bitwise & NumericUtils::head_mask; + mantissa = x_bitwise & NumericUtils::mant_mask; + exponent = (head >> in_mant) & NumericUtils::exp_mask; + sign = head >> (in_exp + in_mant); + bias = NumericUtils::bias; + + uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant); + uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1; + constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2); + + if constexpr(negative_zero_nan) + { + if((x_bitwise & nan_mask) == nan_mask) + return nan_code; + } + else + { + if((x_bitwise & nan_mask) == nan_mask) + return signed_inf + (mantissa != 0 ? 1 : 0); + } + + // check if x is 0.0 + if(x_bitwise == 0) + return 0; + + // First need to check if it is normal or denorm as there is a difference of implict 1 + // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift + // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for + // RNE, no need to add rng. Then probably need to check whether there is carry and adjust + // exponent and mantissa again3 + + // For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits + const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // out_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, out_exponent, exponent_diff; + + if(exponent == 0) + { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 +here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has +exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in +fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. +In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = out_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } + else + { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if(act_exponent <= out_denormal_act_exponent) + { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. + For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implict 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = out_denormal_act_exponent - act_exponent; + } + else + { // both fp32/fp16 and f8 are in normal range + exponent_diff = + 0; // exponent_diff=0 does not mean there is no difference for this case, + // act_exponent could be larger. Just that it does not need shift mantissa + } + mantissa += (1 << in_mant); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) == + (1 << (in_mant - out_mant + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we + shift right as shift right could rip off some residual part and make something not midpoint look + like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than + midpoint, but after shift right by 4 bits, it would look like midpoint. */ + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1 << in_mant); + // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent + out_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + bool odd = + mantissa & + (1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1 + mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; + + // Now we deal with overflow + if(out_exponent == 0) + { + if((1 << in_mant) & mantissa) + { + out_exponent = 1; // denormal overflow to become normal, promote exponent + // No need to make 1 implicit now as it will be addressed later + } + } + else + { + if((1 << (in_mant + 1)) & mantissa) + { + mantissa >>= 1; + out_exponent++; + // No need to make 1 implicit now as it will be addressed later + } + } + + mantissa >>= (in_mant - out_mant); + + if(out_exponent > max_exp) + { + if(clip) + { + mantissa = (1 << out_mant) - 1; + out_exponent = max_exp; + } + else + { + return signed_inf; + } + } + + // check if x is 0.0 or -0.0 + if(out_exponent == 0 && mantissa == 0) + return negative_zero_nan ? 0 : (sign << (out_exp + out_mant)); + mantissa &= (1 << out_mant) - 1; + return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa; +} + +template +__host__ __device__ Y run_cast_from_f8(X x) +{ + // fp8/bf8 exponent/mantissa layout + constexpr int in_exp = NumericUtils::exp; + constexpr int in_mant = NumericUtils::mant; + + // resulting type exponent/mantissa layout + constexpr int out_exp = NumericUtils::exp; + constexpr int out_mant = NumericUtils::mant; + + // prepare the codes + constexpr X nan_code = 0x80; + Y Inf, NegInf, NaN, Neg0; + using T_bitwise = typename NumericUtils::bitwise_type; + + constexpr T_bitwise Inf_bitwise = NumericUtils::Inf; + constexpr T_bitwise NegInf_bitwise = NumericUtils::NegInf; + constexpr T_bitwise NaN_bitwise = NumericUtils::NaN; + constexpr T_bitwise Neg0_bitwise = NumericUtils::Neg0; + + Inf = *(reinterpret_cast(&Inf_bitwise)); + NegInf = *(reinterpret_cast(&NegInf_bitwise)); + NaN = *(reinterpret_cast(&NaN_bitwise)); + Neg0 = *(reinterpret_cast(&Neg0_bitwise)); + + // check if x is 0.0 + if(x == 0) + return static_cast(0); + + // unpack the input + uint32_t sign = x >> (in_exp + in_mant); + uint32_t mantissa = x & ((1 << in_mant) - 1); + int exponent = (x & 0x7F) >> in_mant; + + constexpr int exp_low_cutoff = + (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + T_bitwise retval; + + if constexpr(negative_zero_nan) + { + if(x == nan_code) + return NaN; + } + else + { + if(x == nan_code) + return Neg0; + if(exponent == ((1 << in_exp) - 1)) + return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN; + } + + if((NumericUtils::mant == 10) && (NumericUtils::mant == 2) && !negative_zero_nan) + { + retval = x; + retval <<= 8; + return *(reinterpret_cast(&retval)); + } + + // subnormal input + if(exponent == 0) + { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + clz(mantissa) - (32 - in_mant); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << in_mant) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= out_mant - in_mant; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if(exponent <= 0) + { + mantissa |= 1 << out_mant; + mantissa >>= 1 - exponent; + exponent = 0; + } + + retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; + return *(reinterpret_cast(&retval)); +} + +} // namespace + +template +__host__ __device__ Y cast_to_f8(X x, uint32_t rng) +{ + // check datatypes + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "Only half and float can be casted."); + + return run_cast_to_f8(x, rng); +} + +template +__host__ __device__ Y cast_from_f8(X x) +{ + // check datatype + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported."); + + return run_cast_from_f8(x); +} + +} // namespace ck::utils diff --git a/include/ck/utility/functional.hpp b/include/ck/utility/functional.hpp index 08e730782f386cf5788c64bc04a008dc6cb37b28..91797d24092e3e32ad4a6bd40958952b124d9978 100644 --- a/include/ck/utility/functional.hpp +++ b/include/ck/utility/functional.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/functional2.hpp b/include/ck/utility/functional2.hpp index 6f125ca4c944777d02ae1083334bec4602ad68c4..99c65f4eb85b67231557b1916fae10c6568e676d 100644 --- a/include/ck/utility/functional2.hpp +++ b/include/ck/utility/functional2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/functional3.hpp b/include/ck/utility/functional3.hpp index 06b67ef7e3fdec25a1ab2e18e23c904342233ebe..97605a7adeb8ae64e4e5a32debe9386295923b6c 100644 --- a/include/ck/utility/functional3.hpp +++ b/include/ck/utility/functional3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/functional4.hpp b/include/ck/utility/functional4.hpp index 6eeaf15c9b7ac283f7fcac96c195d28f179b6065..b5f3df8d7c517dfaf01320e41721da174883c2d9 100644 --- a/include/ck/utility/functional4.hpp +++ b/include/ck/utility/functional4.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_FUNCTIONAL4_HPP #define CK_FUNCTIONAL4_HPP diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp index 6a1ca966521710ecdb4ceaad9a83457031f4c508..98f40a4363aa2ddf2908fa497151f756b77d6f94 100644 --- a/include/ck/utility/generic_memory_space_atomic.hpp +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp index 44ff438155d2a9c9ba9b0925d4e9fe95c1fa6bce..77564c6130baf45bfd331e69fc437fb7c7c96d18 100644 --- a/include/ck/utility/get_id.hpp +++ b/include/ck/utility/get_id.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/get_shift.hpp b/include/ck/utility/get_shift.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0a93081cfd01ce41cf305e4d69e127053de788cc --- /dev/null +++ b/include/ck/utility/get_shift.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +template +static constexpr __device__ index_t get_shift() +{ + return (get_shift() + 1); +}; + +template <> +constexpr __device__ index_t get_shift<1>() +{ + return (0); +} + +} // namespace ck diff --git a/include/ck/utility/ignore.hpp b/include/ck/utility/ignore.hpp index ac33cbf9a508f5e6402d17d6285ae34a8196f6eb..f70a182fd4e5c6cc623acb2c65b5551b2a1acd14 100644 --- a/include/ck/utility/ignore.hpp +++ b/include/ck/utility/ignore.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index b65640bfffed31e319549b7152f0c2a93bfaded9..65efaf388ae528feab9fcf06818e111244fa091a 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "data_type.hpp" +#include "type_convert.hpp" namespace ck { @@ -12,13 +13,13 @@ __device__ void inner_product(const TA& a, const TB& b, TC& c); template <> __device__ void inner_product(const float& a, const float& b, float& c) { -#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32) +#if CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32) asm volatile("\n \ v_mac_f32 %0, %1, %2 \n \ " : "=v"(c) : "v"(a), "v"(b), "0"(c)); -#elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32) +#elif CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32) asm volatile("\n \ v_fmac_f32 %0, %1, %2 \n \ " @@ -71,26 +72,42 @@ inner_product(const float4_t& a, const float4_t& b, f c); } +template <> +__device__ void inner_product(const bhalf_t& a, const bhalf_t& b, float& c) +{ + inner_product(type_convert(a), type_convert(b), c); +} + +template <> +__device__ void inner_product(const half_t& a, const half_t& b, float& c) +{ + inner_product(type_convert(a), type_convert(b), c); +} + template <> __device__ void inner_product(const half2_t& a, const half2_t& b, float& c) { #if defined(CK_USE_AMD_V_DOT2_F32_F16) -#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM +#if CK_USE_AMD_V_DOT_INLINE_ASM + // Use 3 x s_nop to avoid hazard (mi200 cdna2 isa page 47 + // https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf + // ) s_nop with parameter 2 is equal to 3 x s_nop asm volatile("\n \ v_dot2_f32_f16 %0, %1, %2, %0\n \ + s_nop 2 \n \ " : "=v"(c) : "v"(a), "v"(b), "0"(c)); #else - c = __builtin_amdgcn_sdot2(a, b, c, false); + c = __builtin_amdgcn_fdot2(a, b, c, false); #endif #else const vector_type a_vector{a}; const vector_type b_vector{b}; static_for<0, 2, 1>{}([&](auto i) { - c += type_convert(a_vector.AsType()[i]) * - type_convert(b_vector.AsType()[i]); + c += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); }); #endif } @@ -162,15 +179,21 @@ __device__ void inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) { #if defined(CK_USE_AMD_V_DOT4_I32_I8) -#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM +#if CK_USE_AMD_V_DOT_INLINE_ASM + // Use 3 x s_nop to avoid hazard (mi200 cdna2 isa page 47 + // https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf + // ) s_nop with parameter 2 is equal to 3 x s_nop asm volatile("\n \ v_dot4_i32_i8 %0, %1, %2, %0\n \ + s_nop 2 \n \ " : "=v"(c) : "v"(bit_cast(a)), "v"(bit_cast(b)), "0"(c)); #else c = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b), c, false); #endif +#elif defined(CK_USE_AMD_V_DOT4_I32_I8_GFX11) + c = __builtin_amdgcn_sudot4(true, bit_cast(a), true, bit_cast(b), c, false); #else const vector_type a_vector{a}; const vector_type b_vector{b}; diff --git a/include/ck/utility/inner_product_dpp8.hpp b/include/ck/utility/inner_product_dpp8.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f079e2ca6490676d6c0e83e9e74dba205a8d6583 --- /dev/null +++ b/include/ck/utility/inner_product_dpp8.hpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "amd_gemm_dpp.hpp" +#include "data_type.hpp" +#include "type_convert.hpp" + +namespace ck { + +namespace dpp8 { + +/// Number of lanes that can share data using DPP8 modifiers. +constexpr index_t lane_group_size = 8; + +template +__device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c); + +// clang-format off +template <> +__device__ void inline_v_dot2c_dpp8_instr<0>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[0, 0, 0, 0, 0, 0, 0, 0]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<1>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[1, 1, 1, 1, 1, 1, 1, 1]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<2>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[2, 2, 2, 2, 2, 2, 2, 2]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<3>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[3, 3, 3, 3, 3, 3, 3, 3]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<4>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[4, 4, 4, 4, 4, 4, 4, 4]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<5>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[5, 5, 5, 5, 5, 5, 5, 5]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<6>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[6, 6, 6, 6, 6, 6, 6, 6]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<7>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[7, 7, 7, 7, 7, 7, 7, 7]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +// clang-format on + +/** + * Dot product of two vectors using `v_dot` instruction with DPP8 submitted as inline assembly. + */ +template +__device__ void inline_v_dot2c_dpp8(const half2_t& a, const half2_t& b, float& c) +{ + static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size, + "DPP8 src broadcast lane out of range <0, 7>."); + if constexpr(ShareA) + { + inline_v_dot2c_dpp8_instr(a, b, c); + } + else + { + inline_v_dot2c_dpp8_instr(b, a, c); + } +} + +/** + * DPP8 instrinsics expects to get an integer mask, hardcoding integers for specific broadcast + * patters. + */ +constexpr std::array IntrinsicMaskDpp8 = { + 0, // 0, 0, 0, 0, 0, 0, 0, 0 + 2396745, // 1, 1, 1, 1, 1, 1, 1, 1 + 4793490, // 2, 2, 2, 2, 2, 2, 2, 2 + 7190235, // 3, 3, 3, 3, 3, 3, 3, 3 + 9586980, // 4, 4, 4, 4, 4, 4, 4, 4 + 11983725, // 5, 5, 5, 5, 5, 5, 5, 5 + 14380470, // 6, 6, 6, 6, 6, 6, 6, 6 + 16777215, // 7, 7, 7, 7, 7, 7, 7, 7 +}; + +/** + * Returns DPP8 sel modifier as an integer required for the intrinsic instruction. + */ +template +constexpr int get_dpp_sel_mask_broadcast() +{ + static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size, + "DPP8 src broadcast lane out of range <0, 7>."); + return IntrinsicMaskDpp8[SrcLaneIdx]; +} + +template +__device__ void intrinsic_fdot2_impl(const half2_t& a, const half2_t& b, float& c) +{ + constexpr int sel_mask = get_dpp_sel_mask_broadcast(); + const half2_t val_from_other_lane = + bit_cast(__builtin_amdgcn_mov_dpp8(bit_cast(a), sel_mask)); + c = __builtin_amdgcn_fdot2(val_from_other_lane, b, c, false); +} + +/** + * Dot product of two vectors using `v_dot` instruction with DPP8 submitted using intrinsics. + */ +template +__device__ void intrinsic_fdot2(const half2_t& a, const half2_t& b, float& c) +{ + if constexpr(ShareA) + { + intrinsic_fdot2_impl(a, b, c); + } + else + { + intrinsic_fdot2_impl(b, a, c); + } +} + +/** + * Dot product of two input vectors `a`, `b` using `v_dot` instructions with DPP modifier. + * + * DPP modifier allows us to share one of the vectors between lanes in a lane group. + * When `ShareA` is set, instruction uses vector `a` from lane `SrcLaneIdx` from the same + * lane group (8 lanes per lane group in DPP8). When `ShareA` is not set, vector `b` is shared. + * Note that all the threads in a lane group uses the same vector - broadcast pattern. + * + * `SrcLaneIdx` must be in range from 0 to 7. + */ +template +__device__ void inner_product_dpp(const TA& a, const TB& b, TC& c) +{ +#if CK_USE_AMD_V_DOT_DPP8_INLINE_ASM + inline_v_dot2c_dpp8(a, b, c); +#else + intrinsic_fdot2(a, b, c); +#endif +} + +} // namespace dpp8 + +} // namespace ck diff --git a/include/ck/utility/integral_constant.hpp b/include/ck/utility/integral_constant.hpp index 9aab4e24214a884bdfb391e0f9040bb3af2630ec..376070eb3d8ac326603b71e52e76949c168f4219 100644 --- a/include/ck/utility/integral_constant.hpp +++ b/include/ck/utility/integral_constant.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/utility/is_detected.hpp b/include/ck/utility/is_detected.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7a324a6c458b3f1b8bb8037ccfac76e5eadceee0 --- /dev/null +++ b/include/ck/utility/is_detected.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +namespace detail { +template class Op, class... Args> +struct detector +{ + using value_t = std::false_type; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> +{ + using value_t = std::true_type; + using type = Op; +}; +} // namespace detail + +struct nonesuch +{ + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; + void operator=(nonesuch const&) = delete; +}; + +template